第五章:常见的神经网络架构
神经网络架构的选择直接影响模型的表现和适用场景。本章将系统介绍三类最具代表性的深度学习架构:卷积神经网络(CNN)、循环神经网络(RNN)及其变体,以及近年来最具影响力的Transformer架构。每节包含原理、结构图、典型应用和PyTorch代码示例,并对比各自的优势与适用场景。
5.1 卷积神经网络(CNN)
5.1.1 原理简介
卷积神经网络(CNN)专为处理具有空间结构的数据(如图像)而设计。其核心思想包括:
- 局部感受野:每个神经元只关注输入的一小块区域。
- 参数共享:同一卷积核在整个输入上滑动,大幅减少参数数量。
- 池化操作:降低特征图尺寸,提取主要特征。
结构示意图:
5.1.2 典型结构
一个典型的CNN由以下几层组成:
- 卷积层:提取局部特征。
- 激活层:常用ReLU函数。
- 池化层:降采样,常用最大池化。
- 全连接层:输出分类结果。
卷积操作示意图:
### 5.1.3 经典架构与应用
- LeNet-5:最早的手写数字识别CNN。
- AlexNet:首次在ImageNet大赛中获胜,推动深度学习热潮。
- VGG:使用更深的网络和小卷积核。
- ResNet:引入残差连接,极大加深网络深度。
PyTorch代码示例:LeNet-5
import torch.nn as nn
class LeNet5(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*4*4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(nn.functional.relu(self.conv1(x)))
x = self.pool(nn.functional.relu(self.conv2(x)))
x = x.view(-1, 16*4*4)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
return self.fc3(x)
5.2 循环神经网络(RNN)及其变体
5.2.1 原理简介
循环神经网络(RNN)适合处理序列数据(如文本、语音、时间序列)。其特点是:
- 隐藏状态:能够记忆前面时刻的信息。
- 参数共享:同一组参数在每个时间步重复使用。
结构展开图:
5.2.2 主要变体
- LSTM(长短期记忆网络):通过门控机制解决RNN的梯度消失问题。
- GRU(门控循环单元):结构更简单,效果接近LSTM。
LSTM单元结构:
5.2.3 典型应用
- 语音识别、文本生成、机器翻译、时间序列预测等。
PyTorch代码示例:LSTM
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
_, (h_n, _) = self.lstm(x)
return self.fc(h_n.squeeze(0))
5.3 Transformer:序列建模的新范式
5.3.1 原理简介
Transformer 架构基于自注意力机制(Self-Attention),能够捕捉序列中任意位置之间的依赖关系,无需像RNN那样按顺序处理数据,支持高效并行计算。
结构示意图:
5.3.2 主要结构
- 编码器(Encoder):输入序列经过多层自注意力和前馈网络,提取全局特征。
- 解码器(Decoder):生成输出序列,同样由多层自注意力和前馈网络组成,并能关注编码器输出。
- 自注意力机制:每个位置都能关注序列中所有其他位置的信息。
- 位置编码:补充序列顺序信息。
5.3.3 典型应用
- 机器翻译(如Google翻译)
- 文本生成(如GPT、ChatGPT)
- 图像识别(ViT:Vision Transformer)
PyTorch代码示例:Transformer
import torch.nn as nn
class SimpleTransformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_layers=2):
super().__init__()
self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.fc = nn.Linear(d_model, 10) # 假设分类任务
def forward(self, x):
# x: [seq_len, batch_size, d_model]
x = self.transformer_encoder(x)
x = x.mean(dim=0) # 池化
return self.fc(x)
5.4 架构对比与选择
架构 | 主要特点 | 优势 | 典型应用 |
---|---|---|---|
CNN | 局部感受野、参数共享 | 图像特征提取高效 | 图像、音频 |
RNN/LSTM | 时序建模、参数共享 | 适合短序列、时序依赖 | 文本、语音 |
Transformer | 全局自注意力、并行计算 | 长距离依赖、训练高效 | NLP、CV、生成模型 |
5.5 实践案例:用CNN提升MNIST手写数字识别
# 省略数据加载部分
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.fc1 = nn.Linear(64*7*7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(nn.functional.relu(self.conv1(x)))
x = self.pool(nn.functional.relu(self.conv2(x)))
x = x.view(-1, 64*7*7)
x = nn.functional.relu(self.fc1(x))
return self.fc2(x)
小结
- CNN:图像领域的主力架构,参数高效,特征提取能力强。
- RNN/LSTM/GRU:序列建模的基础,能捕捉时序依赖。
- Transformer:自注意力机制,适合长距离依赖和大规模并行计算。
- 选择合适的架构是深度学习成功的关键。
return self.fc2(x)