PyTorch Lightning基础教程:快速构建和训练模型
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
前言
在深度学习项目开发中,模型训练往往需要编写大量重复性代码,包括训练循环、验证逻辑、分布式训练等。PyTorch Lightning框架通过提供高级抽象,让开发者能够专注于模型设计而非工程细节。本文将详细介绍如何使用PyTorch Lightning快速构建和训练一个基础模型。
环境准备
首先确保已安装必要的库:
pip install torch torchvision pytorch-lightning
模型架构设计
我们以自编码器(AutoEncoder)为例,它由编码器(Encoder)和解码器(Decoder)两部分组成:
import torch
from torch import nn
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Sequential(
nn.Linear(28 * 28, 64), # 输入层到隐藏层
nn.ReLU(), # 激活函数
nn.Linear(64, 3) # 隐藏层到潜在空间
)
def forward(self, x):
return self.l1(x)
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Sequential(
nn.Linear(3, 64), # 潜在空间到隐藏层
nn.ReLU(), # 激活函数
nn.Linear(64, 28 * 28) # 隐藏层到输出层
)
def forward(self, x):
return self.l1(x)
这个简单的自编码器将28x28=784维的输入压缩到3维潜在空间,再重建回原始维度。
创建LightningModule
PyTorch Lightning的核心是LightningModule
,它封装了模型定义、训练逻辑和优化器配置:
import lightning as L
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
x, _ = batch # 忽略标签
x = x.view(x.size(0), -1) # 展平图像
z = self.encoder(x) # 编码
x_hat = self.decoder(z) # 解码
loss = F.mse_loss(x_hat, x) # 计算重建损失
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
关键点说明:
training_step
定义了单批次训练逻辑configure_optimizers
返回优化器实例- 自动支持GPU训练、混合精度等特性
准备数据
使用MNIST数据集作为示例:
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
transform = transforms.ToTensor()
dataset = MNIST(os.getcwd(), download=True, transform=transform)
train_loader = DataLoader(dataset, batch_size=32)
训练模型
使用Lightning的Trainer
简化训练流程:
# 初始化模型
autoencoder = LitAutoEncoder(Encoder(), Decoder())
# 创建训练器并开始训练
trainer = L.Trainer(max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
Trainer
自动处理了以下内容:
- 训练循环的epoch迭代
- 梯度计算和参数更新
- 日志记录
- 进度显示
与传统PyTorch代码对比
传统PyTorch需要手动编写训练循环:
autoencoder = LitAutoEncoder(Encoder(), Decoder())
optimizer = autoencoder.configure_optimizers()
for epoch in range(10):
for batch_idx, batch in enumerate(train_loader):
loss = autoencoder.training_step(batch, batch_idx)
loss.backward()
optimizer.step()
optimizer.zero_grad()
而随着项目复杂度增加(如添加验证、测试、学习率调度等),手动循环会变得冗长且难以维护。PyTorch Lightning通过标准化这些流程,显著提高了代码的可维护性和可扩展性。
进阶功能
虽然本文展示的是基础用法,但PyTorch Lightning还支持:
- 多GPU/TPU训练
- 16位混合精度训练
- 早停机制
- 模型检查点
- 超参数优化
- 分布式训练
这些功能都可以通过简单的配置实现,无需重写训练逻辑。
总结
PyTorch Lightning通过提供高级抽象,让开发者能够:
- 专注于模型设计而非工程细节
- 减少样板代码量
- 轻松实现复杂训练逻辑
- 保持代码整洁和可维护性
对于初学者,建议从基础用法开始,逐步探索框架提供的各种高级特性。
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考