Pytorch 零基础学习系列 之 张量的拼接与切分

张量的拼接

(1) 使用 torch.cat() 拼接

torch.cat() 将张量按维度 dim 进行拼接,不会扩张张量的维度。

torch.cat(tensors, dim=0, out=None)

其中,

  • tensors:张量序列
  • dim:要拼接的维度
import torch

t = torch.ones((3, 2))
print(t)
t0 = torch.cat([t, t], dim=0) # 在第0个维度上拼接
t1 = torch.cat([t, t], dim=1) # 在第1个维度上拼接
print(t0, '\n\n', t1)

运行结果:

tensor([[1., 1.], [1., 1.], [1., 1.]])

tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]])

tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]])

t2 = torch.cat([t, t, t], dim=0)
print(t2)

运行结果:

tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.]])

(2) 使用 torch.stack() 拼接

torch.stack() 在新创建的维度上进行拼接,会扩张张量的维度。

torch.stack(tensors, dim=0, out=None)

参数:

  • tensors:张量序列
  • dim:要拼接的维度
t = torch.ones((3, 2))
t1 = torch.stack([t, t], dim=2) # 在新创建的维度上进行拼接。
print(t1, t1.shape) # 拼接完会从2维变成3维

tensor([[[1., 1.], [1., 1.]],

[[1., 1.], ​ [1., 1.]],

[[1., 1.], ​ [1., 1.]]]) torch.Size([3, 2, 2])

我们可以看到维度从拼接前的(3, 2)变成了(3, 2, 2),即在最后的维度上进行了拼接。

张量的切分

(1) 使用 torch.chunk() 切分

torch.chunk() 可以将张量按维度dim进行平均切分,return 张量列表。如果不能整除,最后一份张量小于其他张量。

torch.chunk(input, chunks, dim=0)

参数:

  • input:要切分的张量
  • chunks:要切分的份数
  • dim:要切分的维度
a = torch.ones((5, 2))
t = torch.chunk(a, dim=0, chunks=2) # 在这5个维度切分
for idx, t_chunk in enumerate(t):
    print(idx, t_chunk, t_chunk.shape)

运行结果:

0 tensor([[1., 1.], [1., 1.], [1., 1.]]) torch.Size([3, 2])
1 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2])

可以看出后一个张量小于前一个张量的,前者第0个维度上是3,后者是2。

(2) 使用 torch.split() 切分

torch.split() 将张量按维度dim进行切分,return:张量列表

torch.split(tensor, spilt_size_or_sections, dim=0)

参数:

  • tensor:要切分的张量
  • split_size_or_sections:
    • 为int时,表示每一份的长度;
    • 为list时,按list元素切分
  • dim:要切分的维度
a = torch.ones((5, 2))
t = torch.split(a, 2, dim=0) # 指定每个张量长度为2
for idx, t_split in enumerate(t):
    print(idx, t_split, t_split.shape)

运行结果:

0 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 1 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 2 tensor([[1., 1.]]) torch.Size([1, 2])

a = torch.ones((5, 2))
t = torch.split(a, [2, 1, 2], dim=0) # 指定每个张量长度为列表大小[2,1,2]
for idx, t_split in enumerate(t):
    print(idx, t_split, t_split.shape)

运行结果:

0 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2]) 1 tensor([[1., 1.]]) torch.Size([1, 2]) 2 tensor([[1., 1.], [1., 1.]]) torch.Size([2, 2])

a = torch.ones((5, 2))
t = torch.split(a, [2, 1, 1], dim=0) # list中求和部位长度将抛出异常
for idx, t_split in enumerate(t):
    print(idx, t_split, t_split.shape)

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 5 (input tensor’s size at dimension 0), but got split_sizes=[2, 1, 1]

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值