torch.
split
(tensor, split_size_or_sections, dim=0)
split_size_or_sections 为切分后的每块大小,不是切分为多少块
import torch
x = torch.randn(1, 2, 4, 4)
y = torch.split(x, 1, dim=1) # 每块大小为1
# print(x[0])
for i in y:
print(i.size())
a = torch.rand(1, 4, 8, 6)
b = torch.split(a, 2, dim=1) # 每块大小为2
for i in b:
print(i.size())