nn.Linear()
用于实现全连接层(也称为线性层)。它是构建神经网络时常用的组件之一,主要用于将输入数据转换为输出数据,同时学习输入和输出之间的线性关系。
1. nn.Linear
的基本用法
nn.Linear
的基本形式如下:
nn.Linear(in_features, out_features, bias=True)
in_features
:输入特征的数量(即输入数据的维度)。out_features
:输出特征的数量(即输出数据的维度)。bias
:是否添加偏置项,默认为True
。如果设置为False
,则该层不会学习偏置项。
2. 例子
在 PyTorch 中使用 nn.Linear
:
import torch
import torch.nn as nn
# 定义一个线性层
linear_layer = nn.Linear(in_features=10, out_features=5, bias=True)
# 创建一个输入张量,形状为 (batch_size, in_features)
input_tensor = torch.randn(3, 10) # 假设 batch_size=3,每个样本有 10 个特征
# 通过线性层进行前向传播
output_tensor = linear_layer(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)
3. 输出解释
假设输入张量的是形状 (3, 10)
,表示有 3 个样本,每个样本有 10 个特征。经过 nn.Linear
层后,输出张量的形状将是 (3, 5)
,表示每个样本被转换为 5 个特征。
4. 应用场景
nn.Linear
常用于构建神经网络中的全连接层,尤其是在卷积神经网络(CNN)的最后阶段,将卷积层的输出特征图展平后,通过全连接层进行分类或回归任务。
5. 构建一个简单的神经网络
使用 nn.Linear
构建一个简单的神经网络:
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义一个简单的神经网络
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 5) # 第一个全连接层
self.fc2 = nn.Linear(5, 2) # 第二个全连接层
def forward(self, x):
x = F.relu(self.fc1(x)) # 通过第一个全连接层并应用 ReLU 激活函数
x = self.fc2(x) # 通过第二个全连接层
return x
# 创建模型实例
model = SimpleNN()
# 创建输入张量
input_tensor = torch.randn(3, 10) # 假设 batch_size=3,每个样本有 10 个特征
# 前向传播
output_tensor = model(input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)
6. 输出解释
- 输入张量的形状为
(3, 10)
。 - 经过第一个全连接层后,形状变为
(3, 5)
。 - 经过第二个全连接层后,形状变为
(3, 2)
。
总结
nn.Linear
是 PyTorch 中实现全连接层的模块,用于构建神经网络中的线性变换部分。通过组合多个 nn.Linear
层和激活函数,可以构建复杂的神经网络模型。