使用PyTorch Lightning实现最佳模型保存策略

使用PyTorch Lightning实现最佳模型保存策略

deeplearning-models A collection of various deep learning architectures, models, and tips deeplearning-models 项目地址: https://gitcode.com/gh_mirrors/de/deeplearning-models

前言

在深度学习模型训练过程中,如何有效地保存最佳模型是一个重要课题。本文将基于深度学习模型集合中的PyTorch Lightning实现,详细介绍如何利用其内置功能自动保存验证集上表现最好的模型。

PyTorch Lightning简介

PyTorch Lightning是构建在PyTorch之上的高级框架,它通过将研究代码与工程代码分离,大大简化了深度学习模型的开发流程。主要优势包括:

  1. 更清晰的代码结构
  2. 内置训练循环和验证循环
  3. 自动支持分布式训练
  4. 丰富的回调函数系统
  5. 内置模型检查点保存机制

实现最佳模型保存

1. 模型架构定义

我们首先定义一个多层感知机(MLP)模型,继承自pl.LightningModule

class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()
        
        # 初始化准确率计算器
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy() 
        self.test_acc = Accuracy()
        
        # 构建网络结构
        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units:
            layer = nn.Linear(input_size, hidden_unit)
            all_layers.append(layer)
            all_layers.append(nn.ReLU())
            input_size = hidden_unit
        
        all_layers.append(nn.Linear(hidden_units[-1], 10))
        all_layers.append(nn.Softmax(dim=1))
        self.model = nn.Sequential(*all_layers)

2. 训练流程实现

PyTorch Lightning将训练过程分解为几个关键方法:

def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = nn.functional.cross_entropy(logits, y)
    preds = torch.argmax(logits, dim=1)
    self.train_acc.update(preds, y)
    self.log("train_loss", loss, prog_bar=True)
    return loss

def training_epoch_end(self, outs):
    self.log("train_acc", self.train_acc.compute())
    self.train_acc.reset()

3. 验证流程实现

验证流程与训练类似,但需要注意:

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = nn.functional.cross_entropy(logits, y)
    preds = torch.argmax(logits, dim=1)
    self.valid_acc.update(preds, y)
    self.log("valid_loss", loss, prog_bar=True)
    return loss

def validation_epoch_end(self, outs):
    self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
    self.valid_acc.reset()

4. 数据加载模块

使用LightningDataModule封装数据加载逻辑:

class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        MNIST(root=self.data_path, download=True)

    def setup(self, stage=None):
        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )
        
        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )
        
        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )

最佳模型保存策略

1. 配置模型检查点回调

PyTorch Lightning提供了ModelCheckpoint回调,可以灵活配置保存策略:

from pytorch_lightning.callbacks import ModelCheckpoint

# 配置只保存验证准确率最高的1个模型
callbacks = [ModelCheckpoint(
    save_top_k=1,        # 只保存最好的1个模型
    mode='max',          # 监控指标越大越好
    monitor="valid_acc"  # 监控验证集准确率
)]

2. 训练器配置与启动

配置并启动训练器,自动应用模型保存策略:

pytorch_model = MultiLayerPerceptron()

trainer = pl.Trainer(
    max_epochs=15,
    callbacks=callbacks,  # 应用回调
    progress_bar_refresh_rate=50,
)

trainer.fit(model=pytorch_model, datamodule=mnist_dm)

技术要点解析

  1. 监控指标选择:我们选择验证集准确率(valid_acc)作为模型保存的依据,这是分类任务中最直观的评估指标。

  2. 保存策略save_top_k=1表示只保留表现最好的一个模型,避免存储空间浪费。

  3. 模式选择mode='max'表示我们希望监控的指标越大越好,对于准确率、精确率等指标使用max,对于损失值等指标则使用min。

  4. 日志记录:通过self.log()方法记录的指标会自动被ModelCheckpoint回调捕获,用于决定是否保存当前模型。

实际应用建议

  1. 多指标监控:可以配置多个ModelCheckpoint回调,同时监控不同指标,如损失值和准确率。

  2. 定期保存:除了保存最佳模型,还可以配置定期保存检查点,防止训练意外中断。

  3. 模型恢复:保存的模型可以方便地加载并继续训练或用于推理:

    model = MultiLayerPerceptron.load_from_checkpoint(checkpoint_path)
    
  4. 自定义文件名:可以通过filename参数自定义保存的文件名模式,方便后续识别。

总结

通过PyTorch Lightning的ModelCheckpoint回调,我们可以轻松实现最佳模型保存策略,无需手动编写复杂的模型选择和保存逻辑。这种方法不仅代码简洁,而且可靠性高,是深度学习实践中值得掌握的重要技巧。

deeplearning-models A collection of various deep learning architectures, models, and tips deeplearning-models 项目地址: https://gitcode.com/gh_mirrors/de/deeplearning-models

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

霍薇樱Quintessa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值