更简单实用的pytorch——pytorch_lighting
介绍
PyTorch Lightning的优点
PyTorch Lightning 是一个“batteries included”的深度学习框架,适合需要最大灵活性同时大规模增强性能的专业人工智能研究人员和机器学习工程师。
Lightning 组织 PyTorch 代码以删除样板文件并释放可扩展性。
具体而言,Lightning 把深度学习中网络定义、前向传播、优化器、训练方式、训练输出都进行了高级封装,可以使得代码更加简洁易写,同时也能根据用户需求进行灵活调整。
如何安装?
pip install lightning
conda install lightning -c conda-forge
PyTorch Lightning
和 PyTorch
本身不会直接冲突,因为 PyTorch Lightning
是建立在 PyTorch
基础之上的高级封装,旨在简化深度学习模型的训练过程。然而,如果两者的版本不兼容,或者在同一个环境中安装了相互冲突的依赖包,可能会出现问题。为了避免这些问题,应该确保按照官方文档推荐的版本兼容性矩阵来安装相应版本的 PyTorch
和 PyTorch Lightning
PyTorch Lightning
和 PyTorch
的版本对应关系。
使用教程
定义LightningModule
LightningModule 使您的 PyTorch nn.Module 能够在训练步骤(还有可选的验证步骤和测试步骤)内以复杂的方式一起运行。
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def training_step(self, batch, batch_idx):
# training_step defines the train loop.
# it is independent of forward
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
# Logging to TensorBoard (if installed) by default
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)```
定义数据集
Lightning 支持任何可迭代( DataLoader
、 numpy 等)用于训练/验证/测试/预测分割。
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)```
训练模型
Lightning Trainer 将任何 LightningModule 与任何数据集“混合”,并抽象出扩展所需的所有工程复杂性。
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)```
The Lightning Trainer automates 40+ tricks including:
Lightning Trainer 可自动执行 40 多个技巧,包括:
- Epoch 和批次迭代
optimizer.step()
、loss.backward()
、optimizer.zero_grad()
调用- Calling of
model.eval()
, enabling/disabling grads during evaluation
调用model.eval()
,在评估期间启用/禁用等级 - Checkpoints保存和加载
- Tensorboard (see loggers options)
- Multi-GPU support 多 GPU 支持
- TPU 加速器
- 16-bit precision AMP support
使用模型
训练完模型后,您可以导出到 onnx、torchscript 并将其投入生产,或者只是加载权重并运行预测。
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)
# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()
# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)```
可视化训练
如果您安装了tensorboard ,则可以使用它来可视化实验。
在命令行上运行此命令并打开浏览器访问 http://localhost:6006/
tensorboard --logdir .
Supercharge training
使用 Trainer 参数启用高级训练功能。这些是最先进的技术,可以自动集成到您的训练循环中,而无需更改您的代码。
# train on 4 GPUs
trainer = L.Trainer(
devices=4,
accelerator="gpu",
)
# train 1TB+ parameter models with Deepspeed/fsdp
trainer = L.Trainer(
devices=4,
accelerator="gpu",
strategy="deepspeed_stage_2",
precision=16
)
# 20+ helpful flags for rapid idea iteration
trainer = L.Trainer(
max_epochs=10,
min_epochs=5,
overfit_batches=1
)
# access the latest state of the art techniques
trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
最大限度地提高灵活性
Lightning 的核心指导原则是始终提供最大的灵活性,而不隐藏任何 PyTorch。
根据项目的复杂性,Lightning 提供 5 种额外的灵活性。
自定义训练循环
使用 LightningModule 中提供的 20 多种方法(Hook)中的任何一个,在训练循环中的任何位置注入自定义代码。
class LitAutoEncoder(L.LightningModule):
def backward(self, loss):
loss.backward()
扩展训练器
如果您有多行具有类似功能的代码,则可以使用回调将它们轻松分组在一起,并同时打开或关闭所有这些行。
trainer = Trainer(callbacks=[AWSCheckpoints()])
使用raw PyTorch loop
对于某些类型的前沿研究工作,Lightning 为专家提供了以各种方式完全控制优化或训练循环的能力。
拓展阅读
[第 2 级:添加验证和测试集 — PyTorch Lightning 2.3.3 文档 — Level 2: Add a validation and test set — PyTorch Lightning 2.3.3 documentation](https://lightning.ai/docs/pytorch/stable/model/build_model_advanced.html#manual-optimization)