1. 干嘛的
用来重复一个张量中的元素
2. 怎么用
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor
参数:
- input: 输入的张量(必填)
- repeats: 重复的次数, 可以是整数, 也可以是张量(必填)
- dim: 沿着某个维度进行重复(可选)
- output_size: 和重复次数保持一致, 还不知作用是啥.
官网的的例子
>>> import torch
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2) # 未指定dim参数, 则会把输入的input展平, 再重复其中每个元素
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2) # 未指定dim参数, 则会把输入的input展平, 再重复其中每个元素
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1) # 将张量y在dim=1的维度的元素重复3次
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) # 将张量y在dim=0的维度的元素[1,2]重复1次,元素[3,4]重复2次
tensor([[1, 2],
[3, 4],
[3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) # 将张量y在dim=0的维度的元素[1,2]重复1次,元素[3,4]重复2次, 并设定output_size=重复次数相加
tensor([[1, 2],
[3, 4],
[3, 4]])
个人举例
将张量x5和x6变成相同的维度, 再求欧式距离
>>> x5 = torch.tensor([[[1.0, 2.0, 3.0],
... [2.0, 3.0, 4.0]]]) # x5的维度(1, 2, 3)
>>>
>>> x6 = torch.tensor([[[3.0, 4.0, 5.0],
... [4.0, 5.0, 6.0],
... [5.0, 6.0, 7.0]]]) # x6的维度(1, 3, 3)
>>> # 将x5和x6的维度都变成(1, 6, 3)
>>> x5_repeat = torch.repeat_interleave(x5, len(x6[0]), dim=1)
>>> x6_repeat = torch.repeat_interleave(x6, len(x5[0]), dim=1)
>>> # 对x5_repeat和x6_repeat求其欧氏距离
>>> torch.pairwise_distance(x5_repeat, x6_repeat)
tensor([[3.4641, 3.4641, 5.1962, 3.4641, 5.1962, 5.1962]])
3. 与repeat对比
Tensor.repeat(*sizes) → Tensor
直接对一个张量的元素进行块
级别的重复,例如:
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4) # 在dim=0的维度以1,2,3整体为最小单位进行重复
tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
>>> x.repeat(4, 2) # 在dim=0的维度以[1,2,3]整体为最小单位进行重复, 在dim=1的维度以1,2,3整体为最小单位进行重复
tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1) # 在dim=0的维度以[[1,2,3],[1,2,3]]整体为最小单位进行重复,
# 在dim=1的维度以[1,2,3]整体为最小单位进行重复,
# 在dim=2的维度以1,2,3整体为最小单位进行重复
tensor([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])