Pytorch基础 - 7. torch.transpose() 和 torch.permute()

本文介绍了PyTorch中用于张量维度转置的torch.transpose()和torch.permute()函数,详细阐述了它们的用法和示例。torch.transpose()允许交换两个维度,而torch.permute()则可以一次性处理多个维度的转置。在使用这两个函数后,如果要结合view()改变张量形状,必须先调用contiguous()以确保连续性,或者直接使用reshape()替代。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

1. torch.transpose(dim0, dim1) 

2. torch.permute(dims) 

3. 在转置函数之后使用view()


PyTorch中使用torch.transpose() 和 torch.permute()可对张量进行维度的转置,具体内容如下:

1. torch.transpose(dim0, dim1) 

参数: 需要转置的两个维度,即dim0和dim1,顺序可以互换,但一次只能进行两个维度的转换

示例:将shape为[3, 4]的张量,通过transpose()转换成维度为[4, 3]的张量

x = torch.rand(3, 4)
print(x.shape)  # torch.Size([3, 4])
y1 = x.transpose(0, 1)
print(y1.shape)  # torch.Size([4, 3])
y2 = x.transpose(1, 0)
print(y2.shape)  # torch.Size([4, 3])

2. torch.permute(dims) 

参数: dims,可以是list, tuple等,表示进行转置的时候可以一次性进行多维度的转置,且必须传入所有维度数,参数有顺序关系,参数的顺序表示原tensor的哪一维

示例1:进行二维转置

x.permute(1, 0)表示转置后x的shape为[原tensor第1维,原tensor第0维],也就是[4, 3]。更简单的理解就是把原来的维度比作一个数组 list=[3, 4],permute中的数表示数组对应的index,则转置后的x为 [list[1], list[0]] = [4, 3]

x = torch.rand(3, 4)
print(x.shape)  # torch.Size([3, 4])
y1 = x.permute(1, 0)
print(y1.shape)  # torch.Size([4, 3])

示例2:进行多维度同时转置。按照上面数组的方式进行理解,x的原维度表示数组 array = [3, 4, 5, 8],转置后的维度 = [array[1], array[0], array[3], array[2]] = [4, 3, 8, 5]

x = torch.rand(3, 4, 5, 8)
print(x.shape)  # torch.Size([3, 4, 5, 8])
y1 = x.permute(1, 0, 3, 2)
print(y1.shape)  # torch.Size([4, 3, 8, 5])

3. 在转置函数之后使用view()

注意:在使用transpose或permute之后,若要使用view()改变其形状,必须先contiguous()。或者直接使用reshape()直接代替contiguous().view(),具体原因可见本博客PyTorch系列中的 Pytorch基础 - 6. torch.reshape() 和 torch.view()

示例1:直接使用view(),会报错

 示例2:使用contiguous().view(),未报错,且得到了输出结果

x = torch.rand(3, 4)
print(x.shape)  # torch.Size([3, 4])
y1 = x.permute(1, 0)
print(y1.shape)  # torch.Size([4, 3])

z = y1.contiguous().view(-1)
print(z.shape)  # torch.Size([12])
torch.transpose() torch.permute() 都是 PyTorch 中用于变换张量维度的函数,它们的区别在于: torch.transpose() 函数是对张量的转置操作,只支持二维(矩阵)三维(批量矩阵)张量的转置。对于二维张量,可以使用该函数将其转置;对于三维张量,可以指定转置的两个维度进行转置。例如: ``` import torch x = torch.randn(3, 4) print(x) # tensor([[ 2.0514, 1.8634, -0.0375, -1.1761], # [-1.3048, 0.0905, 0.3555, 0.2998], # [-0.5564, 0.4689, -0.8771, 0.9331]]) y = torch.transpose(x, 0, 1) print(y) # tensor([[ 2.0514, -1.3048, -0.5564], # [ 1.8634, 0.0905, 0.4689], # [-0.0375, 0.3555, -0.8771], # [-1.1761, 0.2998, 0.9331]]) z = torch.randn(2, 3, 4) print(z) # tensor([[[ 0.4673, 0.1136, -1.3368, -0.1990], # [-0.5333, 0.0768, -0.6232, -0.8085], # [-1.2438, 1.8073, -0.6008, -1.0195]], # # [[ 0.4232, -0.2118, 1.2122, 0.7345], # [-1.3797, -0.3909, -0.2965, -1.3328], # [-0.8473, -0.6902, 0.1941, -0.8746]]]) w = torch.transpose(z, 0, 1) print(w) # tensor([[[ 0.4673, 0.1136, -1.3368, -0.1990], # [ 0.4232, -0.2118, 1.2122, 0.7345]], # # [[-0.5333, 0.0768, -0.6232, -0.8085], # [-1.3797, -0.3909, -0.2965, -1.3328]], # # [[-1.2438, 1.8073, -0.6008, -1.0195], # [-0.8473, -0.6902, 0.1941, -0.8746]]]) ``` torch.permute() 函数可以对张量的任意维度进行重新排列,可以对任意维度的张量进行操作。例如: ``` import torch x = torch.randn(2, 3, 4) print(x) # tensor([[[-1.1819, 0.6551, 0.3769, -0.7894], # [ 0.2819, -0.5533, -0.2477, -0.8851], # [-0.6033, -1.4273, -0.2664, -0.5853]], # # [[-1.0286, -1.7711, -0.5401, 0.5417], # [-0.4733, 2.4179, 0.0829, 0.7627], # [-0.4668, 0.0573, 0.2071, -0.5846]]]) y = x.permute(2, 0, 1) print(y) # tensor([[[-1.1819, 0.2819, -0.6033], # [-1.0286, -0.4733, -0.4668]], # # [[ 0.6551, -0.5533, -1.4273], # [-1.7711, 2.4179, 0.0573]], # # [[ 0.3769, -0.2477, -0.2664], # [-0.5401, 0.0829, 0.2071]], # # [[-0.7894, -0.8851, -0.5853], # [ 0.5417, 0.7627, -0.5846]]]) ``` 可以看到,torch.permute() 函数可以对张量的任意维度进行重新排列,而 torch.transpose() 函数只能对二维三维张量进行转置。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值