PyTorch 中 nn.Linear 层特性总结笔记

一、核心特性概述

在 PyTorch 中,nn.Linear(in_features, out_features) 线性层总是作用于输入张量的最后一个维度,其他维度被视为 “任意形状” 的 batch 维度。无论输入张量是几维,线性层都会提取最后一个维度进行处理。

  • 处理逻辑:对输入张量最后一个维度执行线性变换,输出张量除最后一个维度变为 out_features 外,其他维度保持不变 。

  • 维度要求:输入张量最后一个维度需等于 in_features ,输出张量最后一个维度为 out_features

二、示例说明

1. 三维输入(常见于 NLP 序列数据)

x = torch.randn(2, 4, 512)  # \[batch\_size, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape)  # 输出: \[2, 4, 512]

线性层仅对最后维度 512 进行变换,[2, 4] 作为批量相关维度保留不变。

2. 四维输入(如图像 Transformer 数据)

x = torch.randn(2, 8, 16, 512)  # \[batch, head, seq\_len, embed\_dim]
linear = nn.Linear(512, 512)
output = linear(x)
print(output.shape)  # 输出: \[2, 8, 16, 512]

同样,线性层仅处理最后维度 512 ,其余维度原样输出。

3. 二维输入(传统 NLP 词向量)

x = torch.randn(32, 512)  # \[batch\_size, embed\_dim]
output = linear(x)
print(output.shape)  # 输出: \[32, 512]

无论输入维度简单或复杂,线性层均基于最后维度执行变换。

三、数学原理

对于形状为 [..., in_features] 的输入张量 x ,经 nn.Linear(in_features, out_features) 变换后,输出张量形状变为 [..., out_features] 。其数学运算公式为:

y=x⋅WT+by = x \cdot W^T + by=xWT+b

其中:

  • W∈Rout_features×in_featuresW \in \mathbb{R}^{out\_features \times in\_features}WRout_features×in_features ,是可学习权重矩阵;

  • b∈Rout_featuresb \in \mathbb{R}^{out\_features}bRout_features ,是可学习偏置向量。

四、应用场景举例(以 Transformer 为例)

输入张量形状 最后一个维度 是否匹配 Linear(embed_dim, embed_dim)应用场景
[2, 4, 512]512✅ 是,可直接使用 常规 Transformer 层间变换
[2, 8, 4, 64]64✅ 是,适用于多头注意力中每个 head 的输出 多头注意力后处理
[10, 20]20✅ 可用于任何需要线性变换的地方 简单词向量变换

五、总结

nn.Linear(embed_dim, embed_dim) 仅关注输入张量最后一个维度是否等于 in_features ,并基于此进行线性变换,其他维度保持不变。该特性赋予了线性层在 RNN、Transformer、CNN 等多种网络结构中灵活应用的能力 。

torch.nn.LinearPyTorch中的一个模块,用于定义一个线性。它接受两个参数,即输入和输出的维度。通过调用torch.nn.Linear(input_dim, output_dim),可以创建一个线性,其中input_dim是输入的维度,output_dim是输出的维度。Linear模块的主要功能是执行线性变换,将输入数据乘以权重矩阵,并加上偏置向量。这个函数的具体实现可以参考PyTorch官方文档中的链接。 在引用中的示例中,linear1是一个Linear模块的实例。可以通过print(linear1.weight.data)来查看linear1的权重。示例中给出了权重的具体数值。 在引用中的示例中,x是一个Linear模块的实例,输入维度为5,输出维度为2。通过调用x(data)来计算线性变换的结果。在这个示例中,输入data的维度是(5,5),输出的维度是(5,2)。可以使用torch.nn.functional.linear函数来实现与torch.nn.Linear相同的功能,其中weight和bias分别表示权重矩阵和偏置向量。 以上是关于torch.nn.Linear的一些介绍和示例。如果需要更详细的信息,可以参考PyTorch官方文档中关于torch.nn.Linear的说明。 https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [torch.nn.Linear详解](https://blog.csdn.net/sazass/article/details/123568203)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [torch.nn.Linear](https://blog.csdn.net/weixin_41620490/article/details/127833324)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [pytorch 笔记:torch.nn.Linear() VS torch.nn.function.linear()](https://blog.csdn.net/qq_40206371/article/details/124473437)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值