reshape函数:重塑形状
torch.reshape(input, shape) → Tensor
返回一个重塑形状的张量,该张量与输入张量数据和元素数量相同。
如果可能,返回的张量将是输入的视图。否则,它将是一个副本。
>>> a = torch.arange(4.)
>>> torch.reshape(a, (2, 2))
tensor([[ 0., 1.],
[ 2., 3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0, 1, 2, 3])
flatten函数:展平
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
展平为一维张量。
如果有start_dim或end_dim参数,则仅展平以start_dim开始并以end_dim结束的标注,元素顺序保持不变。
>>> t = torch.tensor([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
transpose函数:转置
torch.transpose(input, dim0, dim1) → Tensor
返回输入张量的转置结果,dim0和dim1交换。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
permute函数:维度交换
torch.permute(input, dims) → Tensor
返回输入张量的视图,并对输入张量维度进行重新排列。
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> torch.permute(x, (2, 0, 1)).size()
torch.Size([5, 2, 3])
unsqueeze函数:升维
torch.unsqueeze(input, dim) → Tensor
unsqueeze()函数起升维的作用,对输入的既定位置(dim)升一维。
dim范围在:[-input.dim() - 1, input.dim() + 1)之间,比如输入input是一维,dim可以是-2,-1,0,1,而负dim相当于 dim = dim + input.dim() + 1,也就是dim实际上就只有0,1。
则dim=0时数据为行方向升维,dim=1时为列方向升维。
>>> x = torch.tensor([1, 2, 3, 4]) #torch.Size([4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]]) #torch.Size([1, 4])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]]) #torch.Size([4, 1])
squeeze函数:降维(未完成)
torch.squeeze(input, dim=None) → Tensor
cat函数:拼接
torch.cat(tensors, dim=0, *, out=None) → Tensor
拼接指定维度的张量。所有张量必须具有相同的形状(连接维度除外)或为空。
torch.cat()可以看作是torc.split()和torc.chunk()的反向操作。
参数:
tensors (sequence of Tensors) –同类型张量的任何python序列。提供的非空张量必须具有相同的形状,cat维度除外。
dim(int,可选)– 张量连接的维度
关键字参数:
out(张量,可选)– 输出张量。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])