PyTorch reshape函数介绍

torch.reshape 是 PyTorch 用于改变张量形状的函数之一。它不会改变张量的数据,而是重新组织其元素以适应新的形状。


reshape 的使用

torch.reshape(input, shape) → Tensor
  • input:输入张量。
  • shape:新形状,使用整数或 -1 指定各维度大小。
    • -1 表示自动推断该维度大小,使总元素数保持不变。
示例
import torch

# 创建一个形状为 (2, 3) 的张量
x = torch.arange(6).view(2, 3)

# 使用 reshape 改变形状为 (3, 2)
y = torch.reshape(x, (3, 2))

print(y)
# 输出:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

使用 -1 自动推断

z = torch.reshape(x, (-1, 2))
print(z)
# 输出:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

与其他张量形状改变函数的区别

1. view
  • 特点view 也用于改变张量形状,但它要求输入张量在内存中是连续的。
  • 限制:如果张量不是连续的(即非 contiguous),使用 view 会报错,需要先调用 contiguous 方法。
  • 示例
x = torch.arange(6).view(2, 3)
y = x.view(3, 2)  # 可以直接使用

x = x.T  # 转置操作使张量变为非连续
y = x.view(3, 2)  # 会报错
2. permute
  • 特点:用于交换张量的维度,而不是改变形状。
  • 用途:适用于维度重新排列。
x = torch.rand(2, 3, 4)
y = x.permute(1, 0, 2)  # 改变维度顺序
3. resize_
  • 特点:修改张量形状,可能破坏原始数据,慎用。
  • 用途:多用于临时调整张量形状,不推荐在计算中使用。
4. squeeze / unsqueeze
  • 特点
    • squeeze:移除长度为 1 的维度。
    • unsqueeze:添加长度为 1 的维度。
  • 示例
x = torch.rand(1, 3, 1, 4)
y = x.squeeze()  # 去掉长度为 1 的维度
z = x.unsqueeze(2)  # 在第 2 个位置添加一个长度为 1 的维度
5. flatten
  • 特点:将多维张量展平为一维张量,或在指定维度范围内展平。
  • 用途:简化张量为线性输入。
  • 示例
    x = torch.rand(2, 3, 4)
    y = torch.flatten(x)  # 展平为 1D
    z = torch.flatten(x, start_dim=1)  # 从第 1 维开始展平
    print(z.shape)  # torch.Size([2, 12])

    reshape 的优势

  • 灵活性:不需要张量是连续的。
  • 安全性:自动处理非连续张量(相比 view)。
  • 性能:通常不会引入额外开销,尤其在连续内存情况下。
reshape 与 view 的选择
  • 如果确定张量是连续的,可用 view 提高性能。
  • 如果不确定张量是否连续,使用 reshape 更安全。

以下函数在改变张量形状或维度时不会破坏原始数据:

  • reshape
  • view(前提是张量连续)
  • permute
  • transpose
  • squeeze / unsqueeze
  • flatten
  • contiguous

这些操作只会影响数据的组织形式或内存布局,而不会修改数据本身。

总结

  • reshape 是 PyTorch 中改变张量形状的通用函数,灵活且易用。
  • 与其他形状操作函数(如 viewpermutesqueeze 等)的主要区别在于适用场景和对张量内存布局的要求。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值