在PyTorch中,in_place
操作是一种直接更改给定张量的内容而不进行复制的操作。例如,使用x.add_()
或x.sub()
等后缀为_
的操作将改变x本身,而不是创建一个新的张量。这种方法可以帮助节省GPU显存 。
使用 𝐓𝐞𝐧𝐬𝐨𝐫 初始化一个 𝟏 × 𝟑 的矩阵 𝑴 和一个 𝟐 × 𝟏 的矩阵 N,然后利用in_place对 M 和 N 进行减法操作。如果不改变 M 和 N 的形状直接进行in_place操作,报错如下:output with shape [1, 3] doesn't match the broadcast shape [2, 3]
对于两个不同形状的tensor,需要先利用.view()函数将二者形状统一,再进行in_place操作。这里我利用torch.cat()函数对tensor进行拼接再计算。
import torch
M = torch.rand(1, 3)
N = torch.rand(2, 1)
M = torch.cat((M, M), 0)
N = torch.cat((N, N, N), 1)
M.sub_(N)
print(M)