-
flatten()
.flatten()函数用于将输入的多维数组(或张量)转换为一个一维数组(或张量)。在 PyTorch 中,可以使用 .flatten()
方法来将张量(可以是多维的)平铺为一个一维张量。例如,如果有一个二维张量 x
,你可以使用 .flatten()
方法将其平铺为一个一维张量。示例代码如下:
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 将二维张量 x 平铺成一维张量
x_flattened = x.flatten()
print(x_flattened)
输出:
tensor([1, 2, 3, 4, 5, 6])
-
unsqueeze(dim)
增加维度,例如,.unsqueeze(2)
的作用是在张量的第三维(从0开始计数)上增加一个维度。这通常用于在处理图像或音频数据时,需要将单通道数据转换为多通道数据。
例如,如果有一个形状为 (3, 4) 的张量,使用 unsqueeze(2)
后,它的形状将变为 (3, 4, 1)。这个操作可以用来扩展张量的维度,以便与其他张量进行运算或者满足网络模型的输入要求。
以下是一个示例代码
import torch
# 创建一个示例张量
tensor_2d = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 在第三维上增加一个维度
tensor_expanded = tensor_2d.unsqueeze(2)
# 打印