索引、切片、连接、换位
torch.cat()
torch.cat(inputs, dimension=0) → Tensor
在给定维度上对输入张量序列seq进行连接操作.
torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数
可以通过下面例子更好的理解
参数:
- inputs (sequence of Tensors) 可以是任意相同的Tensor类型的python序列。
- dimension (int, optional) –沿着此维连接张量序列。
torch.chunk
torch.chunk(tensor, chunks, dim=0)
在给定维度上将输入张量进行分块儿。
参数:
- tensor (Tensor) – 待分块的输入张量
- chunks (int) – 分块的个数
- dim (int) – 沿着此维度进行分块
torch.gather
torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim,将输入的索引张量index,指定位置的值进行聚合。
参数
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标
- out (Tensor, optional) – 目标张量
torch.index_select
torch.index_select(input, dim, index, out=None) → Tensor
沿着指定维度对输入进行切片,取index中指定的相应项数。((index 为一个 LongTensor)然后返回到一个新的张量,返回的张量与原始张量_Tensor**_有相同的维度(在指定轴上)**
注意:返回的张量不与原始张量共享内存空间**。**
参数:
- input (Tensor) – 输入张量
- dim (int) – 索引的轴
- index (LongTensor) – 包含索引下标的一维张量<