pytorch nearest upsample整数型tensor

本文详细比较了如何用自定义模块实现与PyTorch原生Upsample类别相似的功能,特别关注nearest插值,并探讨了在确保模型可重复性方面的注意事项。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在用 torch.nn.Upsample 给分割 label 上采样时报错:RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'Long'

参考 [1-3],用 [3] 给出的实现。稍微扩展一下,支持 h、w 用不同的 scale factor,并测试其与 PyTorch 的几个 upsample 类的异同,验证 [3] 的实现用 nearest 插值。

Code

  • linear 要 3D 输入、trilinear 要 5D 输入,故此两种插值法没比。
import torch
import torch.nn as nn


class UpsampleDeterministic(nn.Module):
    """deterministic upsample with `nearest` interpolation"""

    def __init__(self, scale_factor=2):
        """
        Input:
            scale_factor: int or (int, int), ratio to scale (along heigth & width)
        """
        super(UpsampleDeterministic, self).__init__()
        if isinstance(scale_factor, (tuple, list)):
            assert len(scale_factor) == 2
            self.scale_h, self.scale_w = scale_factor
        else:
            self.scale_h = self.scale_w = scale_factor
        assert isinstance(self.scale_h, int) and isinstance(self.scale_w, int)

    def forward(self, x):
        """
        Input:
            x: [n, c, h, w], torch.Tensor
        Output:
            upsampled x': [n, c, h * scale_h, w * scale_w]
        """
        return x[:, :, :, None, :, None].expand(
            -1, -1, -1, self.scale_h, -1, self.scale_w).reshape(
                x.size(0), x.size(1), x.size(2) * self.scale_h, x.size(3) * self.scale_w)


# 随机数据
x = torch.rand(2, 3, 4, 4) # [n, c, h, w]
# [3] 的实现
us_det = UpsampleDeterministic((2, 3))
# pytorch 自带的几种实现
us_list = {mode: nn.Upsample(scale_factor=(2, 3), mode=mode)
           for mode in ('nearest', 'bilinear', 'bicubic')}
# linear: 3D
# trilinear: 5D

y_det = us_det(x)
print(y_det.size())
for us_name, us in us_list.items():
    y = us(x)
    print(us_name, y.size(), (y_det != y).sum())

输出:

torch.Size([2, 3, 8, 12])
nearest torch.Size([2, 3, 8, 12]) tensor(0)
bilinear torch.Size([2, 3, 8, 12]) tensor(507)
bicubic torch.Size([2, 3, 8, 12]) tensor(576)

可见 [3] 的实现与 nearest 结果一致。

References

  1. 请慎用torch.nn.Upsample
  2. PyTorch中模型的可复现性
  3. Non Deterministic Behaviour even after cudnn.deterministic = True and cudnn.benchmark=False #12207
<think>好的,用户想了解Upsample方法的使用教程或实现方式。首先,我需要回顾一下Upsample的基本概念。Upsample,也就是上采样,是一种在深度学习,尤其是图像处理中常用的技术,用于增加图像的分辨率或尺寸。常见的应用包括图像超分辨率、语义分割等。 接下来,用户可能希望了解具体的实现方法。根据提供的引用内容,引用[3]提到了PaddlePaddle中的Upsample类,包括不同的插值方法,如最近邻和双线性插值。我需要详细解释这些方法,并给出具体的代码示例。另外,用户提到的引用[1]和[2]虽然主要关于图像修复和YOLO训练,但其中也涉及到了上采样的应用场景,比如在CodeFormer中使用RealESRGAN进行背景上采样,这可能作为实际案例来补充说明。 用户可能需要分步骤的教程,比如如何在不同框架中实现上采样。考虑到引用[3]是PaddlePaddle的文档,我应该重点介绍Paddle的Upsample模块,并比较其他框架如PyTorch的实现方式,这样用户可以有更全面的了解。此外,需要注意解释参数如size、scale_factor、mode的选择对结果的影响,以及align_corners等参数的作用。 还需要考虑用户可能的疑问点,比如不同插值方法的区别,何时使用最近邻,何时使用双线性,以及如何在实际项目中应用上采样。例如,在图像修复中,上采样如何帮助提高图像质量,或者目标检测中如何通过上采样恢复特征图尺寸。 最后,根据系统指令,需要生成相关问题,这些问题应该围绕上采样的原理、不同方法对比、参数调优和应用场景展开,帮助用户深入理解。同时,确保所有数学公式如插值法用正确的LaTeX格式,比如双线性插值的计公式可能需要用$$...$$单独展示,并正确引用相关文献。</think>### Upsample 方法使用教程及实现方式 #### 一、Upsample 概念 Upsample上采样)是深度学习中用于增加数据空间分辨率的方法,常用于图像超分辨率、语义分割等任务。其核心目标是将低分辨率特征图通过插值或反卷积等方式恢复为高分辨率数据[^3]。 #### 二、实现方式及代码示例 ##### 1. **PaddlePaddle 框架** PaddlePaddle 提供 `nn.Upsample` 类,支持多种插值方法: ```python import paddle import paddle.nn as nn # 定义上采样upsample = nn.Upsample( scale_factor=2, # 缩放因子(若 size=None) mode='bilinear', # 插值方式:'nearest', 'bilinear', 'bicubic'等 align_corners=False # 对齐像素角点 ) # 输入数据(假设为 [batch_size, channels, height, width]) input_tensor = paddle.randn([1, 3, 32, 32]) output = upsample(input_tensor) # 输出尺寸变为 [1, 3, 64, 64] ``` 参数说明: - `size`:目标输出尺寸(如 `(64,64)`) - `scale_factor`:尺寸缩放倍数(与 `size` 二选一) - `mode`:插值法,影响计速度和效果[^3]。 ##### 2. **PyTorch 实现对比** PyTorch 中对应接口为 `torch.nn.Upsample`: ```python import torch import torch.nn as nn upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) input_tensor = torch.randn(1, 3, 32, 32) output = upsample(input_tensor) ``` #### 三、数学原理 以双线性插值(Bilinear Interpolation)为例,其计公式为: $$ f(x,y) \approx \frac{y_2 - y}{y_2 - y_1} \left( \frac{x_2 - x}{x_2 - x_1}f(Q_{11}) + \frac{x - x_1}{x_2 - x_1}f(Q_{12}) \right) + \frac{y - y_1}{y_2 - y_1} \left( \frac{x_2 - x}{x_2 - x_1}f(Q_{21}) + \frac{x - x_1}{x_2 - x_1}f(Q_{22}) \right) $$ 其中 $Q_{ij}$ 表示邻近像素点值[^3]。 #### 四、应用场景与最佳实践 1. **图像修复**:如 CodeFormer 使用 RealESRGAN 对背景进行上采样增强[^1]。 2. **目标检测**:在 YOLOv5 中,上采样用于恢复特征图尺寸以提升小目标检测精度[^2]。 3. **语义分割**:通过上采样将编码器输出的低分辨率特征图还原为输入尺寸。 #### 五、选择插值方法的建议 | 方法 | 速度 | 适用场景 | |---------------|--------|------------------------------| | 最近邻插值 | 最快 | 实时性要求高、不追求精细效果 | | 双线性插值 | 中等 | 通用场景(平衡速度与质量) | | 双三次插值 | 较慢 | 需要高质量输出的任务 |
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值