Pytorch-Lightning基本方法介绍

PyTorch Lightning 的 LightningModule 结构化了 PyTorch 代码,包括初始化、训练、验证、测试和推理过程。训练循环中,通过 training_step 方法定义每个步骤的损失计算,validation_step 和 test_step 分别用于验证和测试阶段的损失计算,同时支持在 epoch 级别的度量和数据并行计算。在研究和生产环境中,LightningModule 还支持灵活的前向推理和模型迭代。此外,通过 configure_optimizers 定义优化器。

LIGHTNINGMODULE

LightningModule将PyTorch代码整理成5个部分:

  • Computations (init).
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Minimal Example

所需要的方法:

import pytorch_lightning as pl
class LitModel(pl.LightningModule):

     def __init__(self):
         super().__init__()
         self.l1 = torch.nn.Linear(28 * 28, 10)

     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)

使用下面的代码进行训练:

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()

trainer.fit(model, train_loader)

一些基本方法

Training
Training loop

使用training_step方法来增加training loop

class LitClassifier(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

如果需要在epoch-level进行度量,并进行记录,可以使用*.log*方法

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

如果需要对每个training_step的输出做一些操作,可以通过改写training_epoch_end来实现

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值