PyTorch实现Softmax回归的简洁方法:基于dsgiitr/d2l-pytorch项目的实践指南
引言
在机器学习领域,Softmax回归是多类分类任务的基础模型。本文将基于PyTorch框架,介绍如何简洁高效地实现Softmax回归模型。我们将使用Fashion-MNIST数据集作为示例,展示从数据加载到模型训练的全过程。
环境准备与数据加载
首先我们需要导入必要的库并准备数据集:
import torch
import torch.nn as nn
from d2l import torch as d2l
# 设置批量大小为256
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
这里我们使用了256的批量大小,这是一个在内存效率和训练速度之间取得良好平衡的选择。
模型构建
网络结构设计
Softmax回归模型本质上是一个单层的全连接神经网络:
class Reshape(torch.nn.Module):
def forward(self, x):
return x.view(-1, 784) # 将28x28图像展平为784维向量
net = nn.Sequential(
Reshape(), # 数据预处理层
nn.Linear(784, 10) # 全连接层
)
我们定义了一个Reshape
类来处理输入数据的形状转换,然后接一个具有10个输出单元的全连接层(对应Fashion-MNIST的10个类别)。
参数初始化
良好的初始化对模型训练至关重要:
def init_weights(m):
if type(m) == nn.Linear:
torch.nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)
这里我们使用标准差为0.01的正态分布来初始化权重,这种小范围的初始化有助于模型训练的稳定性。
Softmax与交叉熵损失
数值稳定性问题
Softmax函数定义为:
$\hat y_j = \frac{e^{z_j}}{\sum_{i=1}^{n} e^{z_i}}$
直接计算时可能会遇到数值稳定性问题:
- 上溢(overflow):当$z_i$很大时,$e^{z_i}$可能超过浮点数表示范围
- 下溢(underflow):当$z_i$很小时,$e^{z_i}$可能被舍入为0
解决方案
PyTorch的CrossEntropyLoss
巧妙地结合了Softmax和交叉熵计算,避免了数值问题:
loss = nn.CrossEntropyLoss()
这种实现使用了"log-sum-exp"技巧,直接从logits计算损失,既高效又稳定。
优化器配置
我们使用学习率为0.1的小批量随机梯度下降(SGD):
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
这个简单的优化器在许多情况下表现良好,特别是对于像Softmax回归这样的浅层模型。
模型训练
训练过程简洁明了:
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, lr)
经过10个epoch的训练,模型在测试集上达到了约83%的准确率,这与从零开始实现的版本性能相当,但代码量大大减少。
关键知识点总结
- 网络结构:Softmax回归本质上是单层全连接网络
- 数值稳定性:PyTorch的交叉熵损失函数内部处理了Softmax的数值稳定性问题
- 初始化:小范围随机初始化有助于模型收敛
- 优化:SGD等简单优化器对浅层模型通常足够
实践建议
- 超参数调优:可以尝试调整学习率、批量大小等超参数观察模型表现
- 过拟合处理:如果测试准确率开始下降,考虑添加正则化或早停策略
- 扩展性:这个简洁实现可以轻松扩展到更复杂的深度网络
通过PyTorch的高级API,我们能够以极简的代码实现高效的Softmax回归模型,这为后续更复杂模型的实现奠定了良好基础。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考