torch.nn.Linear 是 pytorch 的线性变换层,定义如下:
Linear(in_features: int, out_features: int, bias: bool = True, device: Any | None = None, dtype: Any | None = None)
全连接层 Fully Connect 一般就就用这个函数来实现。因此在潜意识里,变换的输入张量的 shape 为 (batchsize, in_features),输出张量的 shape 为 (batchsize, out_features)。
当然这是常用的方式,但是 Linear 的输入张量的维度其实并不需要必须为上述的二维,多维也是完全可以的,Linear 仅是对输入的最后一维做线性变换,不影响其他维。
可以看下官网的解释:Linear — PyTorch 1.11.0 documentation

一个例子如下:
import torch
input = torch.randn(30, 20, 10) # [30, 20, 10]
linear = torch.nn.Linear(10, 15) # (*, 10) --> (*, 15)
output = linear(input)
print(output.size()) # 输出 [30, 20, 15]
本文详细介绍了PyTorch中的torch.nn.Linear层,该层用于实现线性变换,不仅限于常见的二维输入,而是可以接受任意维度的张量,只要最后一维对应`in_features`。通过一个实例展示了即使输入张量为三维,依然能正确应用Linear层,并得到期望的输出形状。这表明Linear层的灵活性,可以在各种复杂的神经网络结构中发挥作用。
2017

被折叠的 条评论
为什么被折叠?



