一、核心特性概述
在 PyTorch 中,nn.Linear(in_features, out_features)
线性层总是作用于输入张量的最后一个维度,其他维度被视为 “任意形状” 的 batch 维度。无论输入张量是几维,线性层都会提取最后一个维度进行处理。
-
处理逻辑:对输入张量最后一个维度执行线性变换,输出张量除最后一个维度变为
out_features
外,其他维度保持不变 。 -
维度要求:输入张量最后一个维度需等于
in_features
,输出张量最后一个维度为out_features
。
二、示例说明
1. 三维输入(常见于 NLP 序列数据)
x = torch.randn(2, 4, 512) # \[batch\_size, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape) # 输出: \[2, 4, 512]
线性层仅对最后维度 512
进行变换,[2, 4]
作为批量相关维度保留不变。
2. 四维输入(如图像 Transformer 数据)
x = torch.randn(2, 8, 16, 512) # \[batch, head, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape) # 输出: \[2, 8, 16, 512]
同样,线性层仅处理最后维度 512
,其余维度原样输出。
3. 二维输入(传统 NLP 词向量)
x = torch.randn(32, 512) # \[batch\_size, embed\_dim]
output = linear(x)
print(output.shape) # 输出: \[32, 512]
无论输入维度简单或复杂,线性层均基于最后维度执行变换。
三、数学原理
对于形状为 [..., in_features]
的输入张量 x
,经 nn.Linear(in_features, out_features)
变换后,输出张量形状变为 [..., out_features]
。其数学运算公式为:
y=x⋅WT+by = x \cdot W^T + by=x⋅WT+b
其中:
-
W∈Rout_features×in_featuresW \in \mathbb{R}^{out\_features \times in\_features}W∈Rout_features×in_features ,是可学习权重矩阵;
-
b∈Rout_featuresb \in \mathbb{R}^{out\_features}b∈Rout_features ,是可学习偏置向量。
四、应用场景举例(以 Transformer 为例)
输入张量形状 | 最后一个维度 | 是否匹配 Linear(embed_dim, embed_dim) | 应用场景 |
---|---|---|---|
[2, 4, 512] | 512 | ✅ 是,可直接使用 | 常规 Transformer 层间变换 |
[2, 8, 4, 64] | 64 | ✅ 是,适用于多头注意力中每个 head 的输出 | 多头注意力后处理 |
[10, 20] | 20 | ✅ 可用于任何需要线性变换的地方 | 简单词向量变换 |
五、总结
nn.Linear(embed_dim, embed_dim)
仅关注输入张量最后一个维度是否等于 in_features
,并基于此进行线性变换,其他维度保持不变。该特性赋予了线性层在 RNN、Transformer、CNN 等多种网络结构中灵活应用的能力 。