文章目录
LIGHTNINGMODULE
LightningModule将PyTorch代码整理成5个部分:
- Computations (init).
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
Minimal Example
所需要的方法:
import pytorch_lightning as pl
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
使用下面的代码进行训练:
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()
trainer.fit(model, train_loader)
一些基本方法
Training
Training loop
使用training_step方法来增加training loop
class LitClassifier(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return loss
如果需要在epoch-level进行度量,并进行记录,可以使用*.log*方法
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# logs metrics for each training_step,
# and the average across the epoch, to the progress bar and logger
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
如果需要对每个training_step的输出做一些操作,可以通过改写training_epoch_end来实现
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat

PyTorch Lightning 的 LightningModule 结构化了 PyTorch 代码,包括初始化、训练、验证、测试和推理过程。训练循环中,通过 training_step 方法定义每个步骤的损失计算,validation_step 和 test_step 分别用于验证和测试阶段的损失计算,同时支持在 epoch 级别的度量和数据并行计算。在研究和生产环境中,LightningModule 还支持灵活的前向推理和模型迭代。此外,通过 configure_optimizers 定义优化器。
最低0.47元/天 解锁文章
9352

被折叠的 条评论
为什么被折叠?



