如何将PyTorch代码转换为PyTorch Lightning框架
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
前言
PyTorch Lightning是一个轻量级的PyTorch封装框架,它通过提供清晰的结构和自动化处理许多样板代码,让深度学习研究变得更加简单高效。本文将详细介绍如何将标准的PyTorch代码迁移到PyTorch Lightning框架中。
1. 保留原有模型结构
PyTorch Lightning完全兼容标准的nn.Module
,因此你可以保留原有的模型架构不做任何修改:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, 128)
self.layer_2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
return x
2. 配置训练逻辑
将训练逻辑从传统的训练循环中提取出来,放入training_step
方法中:
import lightning as L
class LitModel(L.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return loss
3. 优化器和学习率调度器配置
将优化器和学习率调度器的配置移到configure_optimizers
方法中:
class LitModel(L.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
4. 验证逻辑配置(可选)
如果需要验证集,可以添加validation_step
方法:
class LitModel(L.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
val_loss = F.cross_entropy(y_hat, y)
self.log("val_loss", val_loss)
5. 测试逻辑配置(可选)
测试逻辑可以放在test_step
方法中:
class LitModel(L.LightningModule):
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
test_loss = F.cross_entropy(y_hat, y)
self.log("test_loss", test_loss)
6. 预测逻辑配置(可选)
预测逻辑可以放在predict_step
方法中:
class LitModel(L.LightningModule):
def predict_step(self, batch, batch_idx):
x, y = batch
pred = self.model(x)
return pred
7. 移除设备相关代码
PyTorch Lightning会自动处理设备分配,因此可以移除所有.cuda()
或.to(device)
调用:
class LitModel(L.LightningModule):
def training_step(self, batch, batch_idx):
# 自动处理设备
z = torch.randn(4, 5, device=self.device)
...
如果需要在__init__
中初始化张量并希望自动移动到设备上,可以使用register_buffer
:
class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features))
8. 使用自定义数据
PyTorch Lightning兼容标准的PyTorch DataLoader:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)
高级技巧
- 单独运行验证集:可以直接调用
trainer.validate()
方法运行验证集 - 测试集使用:测试集不会在训练过程中自动运行,需要显式调用
trainer.test()
- 预测功能:预测功能需要显式调用
trainer.predict()
总结
通过以上步骤,你可以将标准的PyTorch代码优雅地迁移到PyTorch Lightning框架中。这种转换不仅使代码更加模块化和可维护,还能自动获得多GPU训练、混合精度训练等高级功能,同时保持了对PyTorch原生功能的完全兼容性。
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考