在PyTorch中,flatten
函数可以用来将多维张量(tensor)展平成一维张量。这在准备数据输入到神经网络时非常常见,因为很多神经网络层(如全连接层)期望输入是一维的。
简单理解就是将张量合并到一个维度里.
实例:
import torch
# 创建一个多维张量
tensor = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]
])
print(tensor.shape)
# 使用flatten函数展平张量
# 参数start_dim=1表示从第二维开始展平,end_dim=2表示展平到第二维结束
flattened_tensor = tensor.flatten(start_dim=1, end_dim=2)
print(flattened_tensor)
print(flattened_tensor.shape)
# 完全展平张量
completely_flattened_tensor = tensor.flatten()
print(completely_flattened_tensor)
print(completely_flattened_tensor.shape)
# 参数1表示从第1维开始之后全部展平
flattened_tensor_start = tensor.flatten(1)
print(flattened_tensor_start)
print(flattened_tensor_start.shape)
输出:
torch.Size([2, 2, 3])
# 参数start_dim=1表示从第1维开始展平,end_dim=2表示展平到第二维结束
tensor([
[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
torch.Size([2, 6])# 完全展平张量
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
torch.Size([12])# 从第1维开始之后全部展平
tensor([[ 1, 2, 3, 4, 5, 6],
[ 7, 8, 9, 10, 11, 12]])
torch.Size([2, 6])