import torch
w = torch.Tensor(6,10,1)
h = torch.Tensor(6,10,1)
#w和h是大小为6*10*1的张量
将w和h在第二维度上(最后那个1)拼接起来,使用torch.cat()操作。
new = torch.cat(w,h,dim=2)
会出现报错:
TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
* (tuple of Tensors tensors, int dim, *, Tensor out)
* (tuple of Tensors tensors, name dim, *, Tensor out)
TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
* (tuple of Tensors tensors, int dim, *, Tensor out)
* (tuple of Tensors tensors, name dim, *, Tensor out)
这是因为在cat操作中忘记加括号了,不能多加也不能少加括号。最后new的大小为6*10*2
#报错代码
new = torch.cat(w,h,dim=2)
#正确代码
new = torch.cat((w,h),dim=2)