在 PyTorch 中,upsample
函数(实际上更准确地说是nn.Upsample
类)主要用于对输入数据(通常是张量形式的图像或特征图)进行上采样操作,即放大数据的尺寸。
基本语法:
torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)
其中参数设置和
interpolate类似
实例:
import torch
import torch.nn as nn
# 创建一个形状为 (1, 1, 2, 2) 的张量
input_tensor = torch.tensor([[[[1.0, 2.0],
[3.0, 4.0]]]])
print("原始张量形状:", input_tensor.shape)
print("原始张量内容:\n", input_tensor)
# 定义一个 Upsample 层,使用双线性插值,放大2倍
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# 应用上采样
output_tensor = upsample(input_tensor)
print("上采样后的张量形状:", output_tensor.shape)
print("上采样后的张量内容:\n", output_tensor)
输出:scale_factor=2
原始张量形状: torch.Size([1, 1, 2, 2])
原始张量内容:
tensor([[[[1., 2.],
[3., 4.]]]])
上采样后的张量形状: torch.Size([1, 1, 4, 4])
上采样后的张量内容:
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
[1.6667, 2.0000, 2.3333, 2.6667],
[2.3333, 2.6667, 3.0000, 3.3333],
[3.0000, 3.3333, 3.6667, 4.0000]]]])