PyTorch使用(6)-张量形状操作

文章目录

  • 1. reshape函数
    • 1.1. 功能与用法
    • 1.2. 特点
  • 2. transpose和permute函数
    • 2.1. transpose
    • 2.2. permute
    • 2.3. 区别
  • 3. view和contiguous函数
    • 3.1. view
    • 3.2. contiguous
    • 3.3. 特点
  • 4. squeeze和unsqueeze函数
    • 4.1. squeeze
    • 4.2. unsqueeze
  • 5. 应用场景
  • 6. 形状操作综合比较
  • 7. 最佳实践建议
  • 8. 总结

1. reshape函数

1.1. 功能与用法

reshape函数可以改变张量的形状而不改变其数据,新形状的元素总数必须与原张量一致。

import torch

x = torch.arange(6)  # tensor([0, 1, 2, 3, 4, 5])

# 改变形状为2x3
y = x.reshape(2, 3)
"""
tensor([[0, 1, 2],
        [3, 4, 5]])
"""

# 自动推断维度大小
z = x.reshape(3, -1)  # -1表示自动计算该维度大小
"""
tensor([[0, 1],
        [2, 3],
        [4, 5]])
"""

1.2. 特点

  • 不改变原始数据,只改变视图

  • 可以处理非连续内存的张量

  • 当无法返回视图时会自动复制数据

2. transpose和permute函数

2.1. transpose

交换两个指定维度:

x = torch.randn(2, 3, 4)

# 交换维度0和1
y = x.transpose(0, 1)  # 形状变为(3, 2, 4)

# 对于2D张量,transpose相当于矩阵转置
matrix = torch.randn(3, 4)
matrix.T == matrix.transpose(0, 1)  # True

2.2. permute

重新排列所有维度:

x = torch.randn(2, 3, 4, 5)

# 重新排列维度顺序
y = x.permute(2, 0, 3, 1)  # 新形状(4, 2, 5, 3)

2.3. 区别

  • transpose只能交换两个维度
  • permute可以任意重新排列所有维度

3. view和contiguous函数

3.1. view

类似于reshape,但要求张量是连续的:

x = torch.arange(6)

# 改变形状
y = x.view(2, 3)

# 会报错的情况
x_non_contiguous = x.t()  # 转置后不连续
try:
    x_non_contiguous.view(6)
except RuntimeError as e:
    print(e)  # 需要连续张量

3.2. contiguous

使张量在内存中连续排列:

x = torch.randn(3, 4).transpose(0, 1)  # 不连续张量

# 转换为连续张量
x_cont = x.contiguous()  # 可能复制数据

# 现在可以使用view
y = x_cont.view(12)

3.3. 特点

  • view比reshape更快,但有限制

  • 转置、切片等操作可能导致不连续

  • 需要view操作前应检查连续性

4. squeeze和unsqueeze函数

4.1. squeeze

移除所有大小为1的维度:

x = torch.randn(1, 3, 1, 2)

# 移除所有大小为1的维度
y = x.squeeze()  # 形状变为(3, 2)

# 只移除指定维度
z = x.squeeze(dim=0)  # 形状变为(3, 1, 2)

4.2. unsqueeze

在指定位置增加大小为1的维度:

x = torch.randn(3, 4)

# 在维度0增加一个维度
y = x.unsqueeze(0)  # 形状变为(1, 3, 4)

# 在维度1增加一个维度
z = x.unsqueeze(1)  # 形状变为(3, 1, 4)

5. 应用场景

  • unsqueeze常用于广播前的维度对齐

  • squeeze常用于移除不必要的单维度

  • 神经网络输入/输出经常需要调整维度

6. 形状操作综合比较

操作是否改变数据是否要求连续适用场景性能
reshape通用形状改变
view快速形状改变
transpose交换两个维度
permute复杂维度重排
squeeze移除单维度
unsqueeze增加单维度

7. 最佳实践建议

  • 优先使用view:当确定张量连续时,view性能更好

  • 注意连续性:复杂操作后使用is_contiguous()检查

  • 维度顺序:保持合理的维度顺序(N,C,H,W等)

  • 避免频繁reshape:多次形状改变可能降低性能

  • 使用-1推断:合理利用-1自动计算维度大小

# 形状操作典型工作流示例
def prepare_input(data):
    # 增加batch维度
    data = data.unsqueeze(0)
    
    # 确保内存连续
    if not data.is_contiguous():
        data = data.contiguous()
        
    # 改变形状为网络输入格式
    return data.view(1, -1, data.size(-1))

8. 总结

  • reshape:用来改变张量的形状,返回一个新的张量。
  • transpose:交换张量的两个维度。
  • permute:按指定的维度顺序重新排列张量的所有维度。
  • view:用来改变张量的形状,要求张量在内存中是连续的。
  • contiguous:确保张量是连续的,可以在需要 view 操作时使用。
  • squeeze:去除张量中维度为1的维度。
  • unsqueeze:在张量的指定位置添加一个维度。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值