更简单实用的pytorch——pytorch_lighting

更简单实用的pytorch——pytorch_lighting

介绍

PyTorch Lightning的优点

PyTorch Lightning 是一个“batteries included”的深度学习框架,适合需要最大灵活性同时大规模增强性能的专业人工智能研究人员和机器学习工程师。

Lightning 组织 PyTorch 代码以删除样板文件并释放可扩展性。

具体而言,Lightning 把深度学习中网络定义、前向传播、优化器、训练方式、训练输出都进行了高级封装,可以使得代码更加简洁易写,同时也能根据用户需求进行灵活调整。

在这里插入图片描述

在这里插入图片描述

如何安装?

pip install lightning

conda install lightning -c conda-forge

PyTorch LightningPyTorch 本身不会直接冲突,因为 PyTorch Lightning 是建立在 PyTorch 基础之上的高级封装,旨在简化深度学习模型的训练过程。然而,如果两者的版本不兼容,或者在同一个环境中安装了相互冲突的依赖包,可能会出现问题。为了避免这些问题,应该确保按照官方文档推荐的版本兼容性矩阵来安装相应版本的 PyTorchPyTorch Lightning

PyTorch LightningPyTorch版本对应关系

使用教程

定义LightningModule

LightningModule 使您的 PyTorch nn.Module 能够在训练步骤(还有可选的验证步骤和测试步骤)内以复杂的方式一起运行。

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)```

定义数据集

Lightning 支持任何可迭代( DataLoader 、 numpy 等)用于训练/验证/测试/预测分割。

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)```

训练模型

Lightning Trainer 将任何 LightningModule 与任何数据集“混合”,并抽象出扩展所需的所有工程复杂性。

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)```

The Lightning Trainer automates 40+ tricks including:
Lightning Trainer 可自动执行 40 多个技巧,包括:

使用模型

训练完模型后,您可以导出到 onnx、torchscript 并将其投入生产,或者只是加载权重并运行预测。

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)```

可视化训练

如果您安装了tensorboard ,则可以使用它来可视化实验。

在命令行上运行此命令并打开浏览器访问 http://localhost:6006/

tensorboard --logdir .

Supercharge training

使用 Trainer 参数启用高级训练功能。这些是最先进的技术,可以自动集成到您的训练循环中,而无需更改您的代码。

# train on 4 GPUs
trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
 )

# train 1TB+ parameter models with Deepspeed/fsdp
trainer = L.Trainer(
    devices=4,
    accelerator="gpu",
    strategy="deepspeed_stage_2",
    precision=16
 )

# 20+ helpful flags for rapid idea iteration
trainer = L.Trainer(
    max_epochs=10,
    min_epochs=5,
    overfit_batches=1
 )

# access the latest state of the art techniques
trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])

最大限度地提高灵活性

Lightning 的核心指导原则是始终提供最大的灵活性,而不隐藏任何 PyTorch。

根据项目的复杂性,Lightning 提供 5 种额外的灵活性。

自定义训练循环

在这里插入图片描述

使用 LightningModule 中提供的 20 多种方法(Hook)中的任何一个,在训练循环中的任何位置注入自定义代码。

class LitAutoEncoder(L.LightningModule):
    def backward(self, loss):
        loss.backward()
扩展训练器

讲解视频

如果您有多行具有类似功能的代码,则可以使用回调将它们轻松分组在一起,并同时打开或关闭所有这些行。

trainer = Trainer(callbacks=[AWSCheckpoints()])
使用raw PyTorch loop

对于某些类型的前沿研究工作,Lightning 为专家提供了以各种方式完全控制优化或训练循环的能力。

拥有你的循环(高级) — PyTorch Lightning 2.3.3 文档 — Own your loop (advanced) — PyTorch Lightning 2.3.3 documentation

拓展阅读

[第 2 级:添加验证和测试集 — PyTorch Lightning 2.3.3 文档 — Level 2: Add a validation and test set — PyTorch Lightning 2.3.3 documentation](https://lightning.ai/docs/pytorch/stable/model/build_model_advanced.html#manual-optimization)

PyTorch Lightning 教程 — PyTorch Lightning 2.3.3 文档 — PyTorch Lightning Tutorials — PyTorch Lightning 2.3.3 documentation

将模型部署到生产中 — PyTorch Lightning 2.3.3 文档 — Deploy models into production — PyTorch Lightning 2.3.3 documentation

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

hardw_littlew

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值