第三章: 神经网络原理详解与Pytorch入门
第二部分:深度学习框架PyTorch入门
第四节:Pytorch模型构建
内容:如何搭建复杂网络以及如何修改模型与保存
一、构建复杂神经网络结构
在 PyTorch 中,构建复杂模型通常通过继承 nn.Module
类,分模块组织层与前向传播逻辑。
示例:自定义一个卷积神经网络(CNN)
import torch.nn as nn
import torch.nn.functional as F
class ComplexCNN(nn.Module):
def __init__(self):
super(ComplexCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 输入通道1,输出通道32
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2) # 池化层
self.dropout = nn.Dropout(0.25)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 输出: [batch, 32, 14, 14]
x = self.pool(F.relu(self.conv2(x))) # 输出: [batch, 64, 7, 7]
x = x.view(-1, 64 * 7 * 7) # Flatten
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
使用
.view()
或torch.flatten()
将多维张量展开为全连接输入
二、动态修改模型结构
在模型训练或实验过程中,我们可能希望替换已有层或添加额外层。
修改已有层
model = ComplexCNN()
model.fc2 = nn.Linear(128, 5) # 将输出从10类改为5类
添加额外层
from collections import OrderedDict
new_layers = nn.Sequential(OrderedDict([
('fc1', nn.Linear(64 * 7 * 7, 256)),
('relu1', nn.ReLU()),
('fc2', nn.Linear(256, 128)),
('relu2', nn.ReLU()),
('fc3', nn.Linear(128, 10))
]))
model.classifier = new_layers # 动态赋值新的子模块
有些大型网络可使用
model.children()
和model.named_modules()
逐层迭代替换
三、模型保存与加载
1. 保存模型参数(推荐方式)
torch.save(model.state_dict(), 'model_weights.pth')
加载时:
model = ComplexCNN()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
使用
.eval()
切换到推理模式(如关闭 Dropout)
2. 保存整个模型(包含结构)
torch.save(model, 'model_full.pth')
加载时:
model = torch.load('model_full.pth')
model.eval()
保存结构的方式依赖 Python 解释器,建议用于快速调试或模型冻结导出场景
四、训练完整流程回顾
model = ComplexCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
model.train()
for batch in dataloader:
inputs, labels = batch
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
总结
主题 | 内容 |
---|---|
自定义模型 | 继承 nn.Module 实现结构和前向逻辑 |
层结构修改 | 使用 .fc = 替换层,或构建 nn.Sequential 组合新层 |
模型保存推荐方式 | 使用 state_dict() 保存权重,更稳定且通用 |
模型加载 | load_state_dict() 加载权重并 .eval() 进入推理状态 |