学习目标:掌握torch.cat()使用
torch.cat()函数可以将两个张量根据指定的维度拼接起来,不改变维度数。
data1 = torch.randint(0,10,[1,2,3])
data2 = torch.randint(0,10,[1,2,3])
print('data1--->',data1.shape,data1)
print('data2--->',data2.shape,data2)
# 1. 按0维度拼接
data3 = torch.cat([data1,data2],dim=0)
print('data3--->',data3.shape,data3)
# 2. 按1维度拼接
data4 = torch.cat([data1,data2],dim=1)
print('data4--->',data4.shape,data4)
# 3. 按2维度拼接
data5 = torch.cat([data1,data2],dim=2)
print('data5--->',data5.shape,data5)
输出结果:
data1---> torch.Size([1, 2, 3]) tensor([[[2, 4, 6],
[6, 8, 4]]])
data2---> torch.Size([1, 2, 3]) tensor([[[9, 0, 7],
[9, 2, 7]]])
data3---> torch.Size([2, 2, 3]) tensor([[[2, 4, 6],
[6, 8, 4]],
[[9, 0, 7],
[9, 2, 7]]])
data4---> torch.Size([1, 4, 3]) tensor([[[2, 4, 6],
[6, 8, 4],
[9, 0, 7],
[9, 2, 7]]])
data5---> torch.Size([1, 2, 6]) tensor([[[2, 4, 6, 9, 0, 7],
[6, 8, 4, 9, 2, 7]]])

1923

被折叠的 条评论
为什么被折叠?



