1. cat 进行维度拼接
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0) # 按第0维度进行拼接,除拼接之外的维度必须相同
print(c.shape)
结果:torch.Size([9, 32, 8])
2. stack 产生一个新的维度
a = torch.rand(5, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.stack([a, b], dim=0) # 产生一个新的维度,待拼接的向量维度相同
print(c.shape)
结果:torch.Size([2, 5, 32, 8])
3. split: 按所指定的长度拆分
a = torch.rand(6, 32, 8)
b, c = a.split(3, dim=0) # 所给的是拆分后,每个向量的大小,指定拆分维度
print(b.shape)
print(c.shape)
结果:
torch.Size([3, 32, 8])
torch.Size([3, 32, 8])
4. chuck: 按所给数量进行拆分
a = torch.rand(6, 32, 8)
b, c, d = a.chunk(3, dim=0) # 所给的是拆分的个数,即拆分成多少个
print(b.shape)
print(c.shape)
结果:
torch.Size([2,