代码
import numpy as np
import torch
# Show how an integer index drops an axis, and how unsqueeze/expand put it back
# and broadcast it. The selection is done first in NumPy, then in PyTorch.

# A 2x3x4 array filled with 0..23 in row-major order.
a = np.arange(2 * 3 * 4).reshape((2, 3, 4))

# Selecting the last position along axis 1 removes that axis: (2, 3, 4) -> (2, 4).
b = a[:, -1]
print(b.shape)

# Wrap the NumPy buffer as a tensor (torch.from_numpy shares memory, no copy).
x = torch.from_numpy(a)

# Same selection in torch, then re-insert a length-1 axis at dim 1 -> (2, 1, 4).
y = torch.unsqueeze(x[:, -1], dim=1)
print(y.shape)

# Broadcast the singleton dim 1 up to length 4 as a view -> (2, 4, 4).
z = y.expand(2, 4, 4)
print(z.shape)
输出
(2, 4)
torch.Size([2, 1, 4])
torch.Size([2, 4, 4])