interpolate
函数主要用于对张量(通常是图像数据或特征图)进行上采样(放大)或下采样(缩小)操作,也称为插值操作。
在深度学习中,特别是在处理图像相关的任务时,经常需要改变特征图的尺寸,以匹配不同模块之间的输入输出要求,或者恢复到合适的分辨率用于后续的可视化、计算损失等操作。
基本语法:
torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)
其中scale_factor
和size
参数是互斥只需要设置一个
scale_factor就是乘,size就是缩放到指定尺寸
示例:size=(4, 4)
import torch
import torch.nn.functional as F
# 创建一个形状为 (1, 1, 2, 2) 的张量
input_tensor = torch.tensor([[[[1.0, 2.0],
[3.0, 4.0]]]])
print("原始张量形状:", input_tensor.shape)
print("原始张量内容:\n", input_tensor)
# 上采样到 (4, 4) 的尺寸
upsampled_tensor = F.interpolate(input_tensor, size=(4, 4), mode='bilinear', align_corners=True)
print("上采样后的张量形状:", upsampled_tensor.shape)
print("上采样后的张量内容:\n", upsampled_tensor)
输出:
原始张量形状: torch.Size([1, 1, 2, 2])
原始张量内容:
tensor([[[[1., 2.],
[3., 4.]]]])
上采样后的张量形状: torch.Size([1, 1, 4, 4])
上采样后的张量内容:
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])
实例:scale_factor=(4, 4)
import torch
import torch.nn.functional as F
# 创建一个形状为 (1, 1, 2, 2) 的张量
input_tensor = torch.tensor([[[[1.0, 2.0],
[3.0, 4.0]]]])
print("原始张量形状:", input_tensor.shape)
print("原始张量内容:\n", input_tensor)
# 上采样使用scale_factor = (4, 4) 的尺寸
upsampled_tensor = F.interpolate(input_tensor, scale_factor=(4, 4), mode='bilinear', align_corners=True)
print("上采样后的张量形状:", upsampled_tensor.shape)
print("上采样后的张量内容:\n", upsampled_tensor)
输出:
原始张量形状: torch.Size([1, 1, 2, 2])
原始张量内容:
tensor([[[[1., 2.],
[3., 4.]]]])
上采样后的张量形状: torch.Size([1, 1, 8, 8])
上采样后的张量内容:
tensor([[[[1.0000, 1.1429, 1.2857, 1.4286, 1.5714, 1.7143, 1.8571, 2.0000],
[1.2857, 1.4286, 1.5714, 1.7143, 1.8571, 2.0000, 2.1429, 2.2857],
[1.5714, 1.7143, 1.8571, 2.0000, 2.1429, 2.2857, 2.4286, 2.5714],
[1.8571, 2.0000, 2.1429, 2.2857, 2.4286, 2.5714, 2.7143, 2.8571],
[2.1429, 2.2857, 2.4286, 2.5714, 2.7143, 2.8571, 3.0000, 3.1429],
[2.4286, 2.5714, 2.7143, 2.8571, 3.0000, 3.1429, 3.2857, 3.4286],
[2.7143, 2.8571, 3.0000, 3.1429, 3.2857, 3.4286, 3.5714, 3.7143],
[3.0000, 3.1429, 3.2857, 3.4286, 3.5714, 3.7143, 3.8571, 4.0000]]]])
另外mode参数不同的mode会有不同效果应用于不同的情况:
'nearest'
(最近邻插值):计算最快,输出像素值离散。放大图像易有锯齿,缩小会丢信息,但在对平滑度要求不高、要保留像素离散特性或快速下采样场景有用。'bilinear'
(双线性插值):用于二维数据,考虑周围四个像素加权插值。放大图像过渡自然、细节恢复好,计算复杂度适中,是平衡质量和成本的常用方法,align_corners
参数会影响角点对齐效果。'trilinear'
(三线性插值):用于三维数据,类似双线性插值但考虑周围 8 个体素,能保持三维数据细节,但计算成本高,适用于三维图像、建模等场景。
...
align_corners参数:
align_corners=True
:输入和输出张量的角点严格对齐。
适用于需要精确空间对齐的任务。
可能导致内部像素的插值不一致。
align_corners=False
:输入和输出张量的角点不严格对齐。
插值过程更注重保持比例关系,提供更平滑和一致的结果。
适用于大多数需要调整特征图尺寸的任务,如语义分割、特征图融合等。