张量的拼接
(1) 使用 torch.cat() 拼接
torch.cat() 将张量按维度 dim 进行拼接,不会扩张张量的维度。
torch.cat(tensors, dim=0, out=None)
其中,
- tensors:张量序列
- dim:要拼接的维度
import torch
t = torch.ones((3, 2))
print(t)
t0 = torch.cat([t, t], dim=0) # 在第0个维度上拼接
t1 = torch.cat([t, t], dim=1) # 在第1个维度上拼接
print(t0, '\n\n', t1)
运行结果:
tensor([[1., 1.], [1., 1.], [1., 1.]])
tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]])
tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]])
t2 = torch.cat([t, t, t], dim=0)
print(t2)
运行结果:
tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]])
(2) 使用 torch.stack() 拼接
torch.stack() 在新创建的维度上进行拼接,会扩张张量的维度。
torch.stack(tensors, dim=0, out=None)
参数:
- tensors:张量序列
- dim:要拼接的维度
t = torch.ones((3, 2))
t1 = torch.stack([t, t], dim=2) # 在新创建的维度上进行拼接。
print(t1, t1.shape) # 拼接完会从2维变成3维
tensor([[[1., 1.], [1., 1.]],
[[1., 1.], [1., 1.]],
[[1., 1.], [1., 1.]]]) torch.Size([3, 2, 2])
我们可以看到维度从拼接前的(3, 2)变成了(3, 2, 2),即在最后的维度上进行了拼接。
张量的切分
(1) 使用 torch.chunk() 切分
torch.chunk() 可以将张量按维度dim进行平均切分,return 张量列表。如果不能整除,最后一份张量小于其他张量。
torch.chunk(input, chunks, dim=0)
参数:
- input:要切分的张量
- chunks:要切分的份数
- dim:要切分的维度
a = torch.ones((5, 2))
t = torch.chunk(a, dim=0, chunks=2) # 在这5个维度切分
for idx, t_chunk in enumerate(t):
print(idx, t_chunk, t_chunk.shape)
运行结果:
0 tensor([[1., 1.], [1., 1.], [1., 1.]]) torch.Size([3, 2])
1 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2])
可以看出后一个张量小于前一个张量的,前者第0个维度上是3,后者是2。
(2) 使用 torch.split() 切分
torch.split() 将张量按维度dim进行切分,return:张量列表
torch.split(tensor, spilt_size_or_sections, dim=0)
参数:
- tensor:要切分的张量
- split_size_or_sections:
- 为int时,表示每一份的长度;
- 为list时,按list元素切分
- dim:要切分的维度
a = torch.ones((5, 2))
t = torch.split(a, 2, dim=0) # 指定每个张量长度为2
for idx, t_split in enumerate(t):
print(idx, t_split, t_split.shape)
运行结果:
0 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 1 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 2 tensor([[1., 1.]]) torch.Size([1, 2])
a = torch.ones((5, 2))
t = torch.split(a, [2, 1, 2], dim=0) # 指定每个张量长度为列表大小[2,1,2]
for idx, t_split in enumerate(t):
print(idx, t_split, t_split.shape)
运行结果:
0 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 1 tensor([[1., 1.]]) torch.Size([1, 2]) 2 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2])
a = torch.ones((5, 2))
t = torch.split(a, [2, 1, 1], dim=0) # list中求和部位长度将抛出异常
for idx, t_split in enumerate(t):
print(idx, t_split, t_split.shape)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 5 (input tensor’s size at dimension 0), but got split_sizes=[2, 1, 1]