interpolate函数

interpolate函数主要用于对张量(通常是图像数据或特征图)进行上采样(放大)或下采样(缩小)操作,也称为插值操作。

在深度学习中,特别是在处理图像相关的任务时,经常需要改变特征图的尺寸,以匹配不同模块之间的输入输出要求,或者恢复到合适的分辨率用于后续的可视化、计算损失等操作。

基本语法

torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None)

其中scale_factorsize参数是互斥只需要设置一个

scale_factor就是乘,size就是缩放到指定尺寸

示例:size=(4, 4)

import torch
import torch.nn.functional as F
# 创建一个形状为 (1, 1, 2, 2) 的张量
input_tensor = torch.tensor([[[[1.0, 2.0],
                               [3.0, 4.0]]]])

print("原始张量形状:", input_tensor.shape)
print("原始张量内容:\n", input_tensor)
# 上采样到 (4, 4) 的尺寸
upsampled_tensor = F.interpolate(input_tensor, size=(4, 4), mode='bilinear', align_corners=True)

print("上采样后的张量形状:", upsampled_tensor.shape)
print("上采样后的张量内容:\n", upsampled_tensor)

输出:

原始张量形状: torch.Size([1, 1, 2, 2])
原始张量内容:
 tensor([[[[1., 2.],
          [3., 4.]]]])
上采样后的张量形状: torch.Size([1, 1, 4, 4])
上采样后的张量内容:
 tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
          [1.6667, 2.0000, 2.3333, 2.6667],
          [2.3333, 2.6667, 3.0000, 3.3333],
          [3.0000, 3.3333, 3.6667, 4.0000]]]])

实例:scale_factor=(4, 4)

import torch
import torch.nn.functional as F
# 创建一个形状为 (1, 1, 2, 2) 的张量
input_tensor = torch.tensor([[[[1.0, 2.0],
                               [3.0, 4.0]]]])

print("原始张量形状:", input_tensor.shape)
print("原始张量内容:\n", input_tensor)
# 上采样使用scale_factor = (4, 4) 的尺寸
upsampled_tensor = F.interpolate(input_tensor, scale_factor=(4, 4), mode='bilinear', align_corners=True)

print("上采样后的张量形状:", upsampled_tensor.shape)
print("上采样后的张量内容:\n", upsampled_tensor)

输出:

原始张量形状: torch.Size([1, 1, 2, 2])
原始张量内容:
 tensor([[[[1., 2.],
          [3., 4.]]]])
上采样后的张量形状: torch.Size([1, 1, 8, 8])
上采样后的张量内容:
 tensor([[[[1.0000, 1.1429, 1.2857, 1.4286, 1.5714, 1.7143, 1.8571, 2.0000],
          [1.2857, 1.4286, 1.5714, 1.7143, 1.8571, 2.0000, 2.1429, 2.2857],
          [1.5714, 1.7143, 1.8571, 2.0000, 2.1429, 2.2857, 2.4286, 2.5714],
          [1.8571, 2.0000, 2.1429, 2.2857, 2.4286, 2.5714, 2.7143, 2.8571],
          [2.1429, 2.2857, 2.4286, 2.5714, 2.7143, 2.8571, 3.0000, 3.1429],
          [2.4286, 2.5714, 2.7143, 2.8571, 3.0000, 3.1429, 3.2857, 3.4286],
          [2.7143, 2.8571, 3.0000, 3.1429, 3.2857, 3.4286, 3.5714, 3.7143],
          [3.0000, 3.1429, 3.2857, 3.4286, 3.5714, 3.7143, 3.8571, 4.0000]]]])

另外mode参数不同的mode会有不同效果应用于不同的情况:

  • 'nearest'(最近邻插值):计算最快,输出像素值离散。放大图像易有锯齿,缩小会丢信息,但在对平滑度要求不高、要保留像素离散特性或快速下采样场景有用。
  • 'bilinear'(双线性插值):用于二维数据,考虑周围四个像素加权插值。放大图像过渡自然、细节恢复好,计算复杂度适中,是平衡质量和成本的常用方法,align_corners参数会影响角点对齐效果。
  • 'trilinear'(三线性插值):用于三维数据,类似双线性插值但考虑周围 8 个体素,能保持三维数据细节,但计算成本高,适用于三维图像、建模等场景。

...

align_corners参数:

align_corners=True

输入和输出张量的角点严格对齐。

适用于需要精确空间对齐的任务。

可能导致内部像素的插值不一致。

align_corners=False

输入和输出张量的角点不严格对齐。

插值过程更注重保持比例关系,提供更平滑和一致的结果。

适用于大多数需要调整特征图尺寸的任务,如语义分割、特征图融合等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值