使用PyTorch Lightning实现最佳模型保存策略
前言
在深度学习模型训练过程中,如何有效地保存最佳模型是一个重要课题。本文将基于深度学习模型集合中的PyTorch Lightning实现,详细介绍如何利用其内置功能自动保存验证集上表现最好的模型。
PyTorch Lightning简介
PyTorch Lightning是构建在PyTorch之上的高级框架,它通过将研究代码与工程代码分离,大大简化了深度学习模型的开发流程。主要优势包括:
- 更清晰的代码结构
- 内置训练循环和验证循环
- 自动支持分布式训练
- 丰富的回调函数系统
- 内置模型检查点保存机制
实现最佳模型保存
1. 模型架构定义
我们首先定义一个多层感知机(MLP)模型,继承自pl.LightningModule
:
class MultiLayerPerceptron(pl.LightningModule):
def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
super().__init__()
# 初始化准确率计算器
self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()
# 构建网络结构
input_size = image_shape[0] * image_shape[1] * image_shape[2]
all_layers = [nn.Flatten()]
for hidden_unit in hidden_units:
layer = nn.Linear(input_size, hidden_unit)
all_layers.append(layer)
all_layers.append(nn.ReLU())
input_size = hidden_unit
all_layers.append(nn.Linear(hidden_units[-1], 10))
all_layers.append(nn.Softmax(dim=1))
self.model = nn.Sequential(*all_layers)
2. 训练流程实现
PyTorch Lightning将训练过程分解为几个关键方法:
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = nn.functional.cross_entropy(logits, y)
preds = torch.argmax(logits, dim=1)
self.train_acc.update(preds, y)
self.log("train_loss", loss, prog_bar=True)
return loss
def training_epoch_end(self, outs):
self.log("train_acc", self.train_acc.compute())
self.train_acc.reset()
3. 验证流程实现
验证流程与训练类似,但需要注意:
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = nn.functional.cross_entropy(logits, y)
preds = torch.argmax(logits, dim=1)
self.valid_acc.update(preds, y)
self.log("valid_loss", loss, prog_bar=True)
return loss
def validation_epoch_end(self, outs):
self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
self.valid_acc.reset()
4. 数据加载模块
使用LightningDataModule
封装数据加载逻辑:
class MnistDataModule(pl.LightningDataModule):
def __init__(self, data_path='./'):
super().__init__()
self.data_path = data_path
self.transform = transforms.Compose([transforms.ToTensor()])
def prepare_data(self):
MNIST(root=self.data_path, download=True)
def setup(self, stage=None):
mnist_all = MNIST(
root=self.data_path,
train=True,
transform=self.transform,
download=False
)
self.train, self.val = random_split(
mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
)
self.test = MNIST(
root=self.data_path,
train=False,
transform=self.transform,
download=False
)
最佳模型保存策略
1. 配置模型检查点回调
PyTorch Lightning提供了ModelCheckpoint
回调,可以灵活配置保存策略:
from pytorch_lightning.callbacks import ModelCheckpoint
# 配置只保存验证准确率最高的1个模型
callbacks = [ModelCheckpoint(
save_top_k=1, # 只保存最好的1个模型
mode='max', # 监控指标越大越好
monitor="valid_acc" # 监控验证集准确率
)]
2. 训练器配置与启动
配置并启动训练器,自动应用模型保存策略:
pytorch_model = MultiLayerPerceptron()
trainer = pl.Trainer(
max_epochs=15,
callbacks=callbacks, # 应用回调
progress_bar_refresh_rate=50,
)
trainer.fit(model=pytorch_model, datamodule=mnist_dm)
技术要点解析
-
监控指标选择:我们选择验证集准确率(
valid_acc
)作为模型保存的依据,这是分类任务中最直观的评估指标。 -
保存策略:
save_top_k=1
表示只保留表现最好的一个模型,避免存储空间浪费。 -
模式选择:
mode='max'
表示我们希望监控的指标越大越好,对于准确率、精确率等指标使用max,对于损失值等指标则使用min。 -
日志记录:通过
self.log()
方法记录的指标会自动被ModelCheckpoint
回调捕获,用于决定是否保存当前模型。
实际应用建议
-
多指标监控:可以配置多个
ModelCheckpoint
回调,同时监控不同指标,如损失值和准确率。 -
定期保存:除了保存最佳模型,还可以配置定期保存检查点,防止训练意外中断。
-
模型恢复:保存的模型可以方便地加载并继续训练或用于推理:
model = MultiLayerPerceptron.load_from_checkpoint(checkpoint_path)
-
自定义文件名:可以通过
filename
参数自定义保存的文件名模式,方便后续识别。
总结
通过PyTorch Lightning的ModelCheckpoint
回调,我们可以轻松实现最佳模型保存策略,无需手动编写复杂的模型选择和保存逻辑。这种方法不仅代码简洁,而且可靠性高,是深度学习实践中值得掌握的重要技巧。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考