构造
构造一个[0]*x0+[1]*x1+[2]*x2…
x = [1024, 256, 512, 4096, 8192, 128, 64, 32]
torch.tensor(np.repeat(np.arange(len(x)), x)) # 这个快
torch.repeat_interleave(torch.tensor(x)) # 插入式repeat,和torch.repeat不同
索引
从一个token序列中取出指定index的子序列
tensor(bs, seq_len, hidden_size)
index(bs, num_selecet)
torch.gather(tensor, 1, index.unsqueeze(-1).repeat(1,1,tensor.shape[-1]))
1万+

被折叠的 条评论
为什么被折叠?



