【PyTorch】PyTorch 中改变张量形状的几种方法

PyTorch 中改变张量形状的几种方法

在深度学习领域,PyTorch 是一个广泛使用的框架,它提供了丰富的API来处理张量(tensor)。在模型开发过程中,我们经常需要改变张量的形状以满足特定的需求。本文将介绍在 PyTorch 中改变张量形状的几种方法,并给出推荐的使用场景。比如:我们想合并一个张量的最后两个维度。

一、方法

1. 使用 reshape 方法

reshape 方法可以改变张量的形状而不改变其数据。这是最常用的方法之一,因为它不要求原始张量在内存中是连续的。

import torch
# 创建一个随机初始化的张量
keycache = torch.rand([21923, 16, 1, 128])
# 使用 reshape 改变形状
keycache_reshaped = keycache.reshape(keycache.size(0), keycache.size(1), -1)
print(keycache_reshaped.shape)

在上面的代码中,我们通过指定前两个维度的大小,并使用 -1 自动计算最后一个维度的大小,来改变张量的形状。

2. 使用 view 方法

view 方法与 reshape 类似,但它要求原始张量在内存中是连续的。如果张量是连续的,view 可以更高效地工作。

# 使用 view 改变形状
keycache_reshaped = keycache.view(keycache.size(0), keycache.size(1), -1)
print(keycache_reshaped.shape)

二、技巧

1. 解包获取维度大小

可以通过解包操作直接从张量的 size 属性中获取维度的大小,然后使用这些值来改变形状。

# 使用解包操作获取维度大小并改变形状
# 使用 _ 来忽略不需要的维度,因为这里我们只关心前两个维度。
n, m, _, _ = keycache.size()
keycache_reshaped = keycache.reshape(n, m, -1)
print(keycache_reshaped.shape)

这种方法在代码中更简洁,并且当只需要部分维度的大小时非常有用。

2. 切片获取维度大小

另一种简洁的方法是使用切片解包来获取维度大小,然后再使用 reshape
这里的 * 操作符用于解包 keycache.shape[:2] 这个元组,将元组中的元素作为独立的参数传递给 reshape 方法。其中前两个维度保持不变,最后一个维度由 -1 自动计算,以保持元素总数不变。

# 使用切片和 reshape 改变形状
keycache_reshaped = keycache.reshape(*keycache.shape[:2], -1)
print(keycache_reshaped.shape)

这种方法不仅代码更简洁,而且易于理解。

三、推荐

选择哪种方法取决于你的具体需求。如果你不确定张量是否在内存中连续,或者不关心性能,那么 reshape 方法是一个更安全的选择。如果你确信张量是连续的,并且需要最优性能,那么 view 方法可能是最佳选择。
总之,这几种方法各有千秋,你可以根据实际情况和个人偏好来选择使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值