-
torch.cat((input1, input2, ... ), dim=?)
torch.cat()可以将多个tensor在dim维度上进行拼接。如下:x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int) x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int) cat1 = torch.cat((x1, x2),0) cat2 = torch.cat((x1, x2),1) print(cat1,cat2) # 输出 tensor([[11, 21, 31], [21, 31, 41], [12, 22, 32], [22, 32, 42]], dtype=torch.int32) tensor([[11, 21, 31, 12, 22, 32], [21, 31, 41, 22, 32, 42]], dtype=torch.int32) -
.shape[i]
shape函数的功能是读取tensor某个维度的长度
对于图像来说:
image.shape[0]——图片高
image.shape[1]——图片长
image.shape[2]——图片通道数
而对于矩阵来说:
shape[0]:表示矩阵的行数
shape[1]:表示矩阵的列数
注:-1代表最后一个,所以shape[-1]代表最后一个维度,如在二维张量里,shape[-1]表示列数
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
print(x1.shape[0])
print(x1.shape[1])
print(x1.shape[-1])
# 输出
2
3
3
博客介绍了PyTorch中torch.cat()函数和shape函数的使用。torch.cat()可将多个tensor在指定维度拼接,shape函数能读取tensor某个维度的长度,还分别说明了在图像和矩阵场景下不同维度索引对应的含义,如-1代表最后一个维度。
837

被折叠的 条评论
为什么被折叠?



