PyTorch Lightning
继续训练
# Resume training with PyTorch Lightning.
# NOTE: `ModelCheckpoint.filename` is a naming TEMPLATE, not a path — the
# directory goes in `dirpath` and the ".ckpt" extension is appended
# automatically. The original code passed a full checkpoint path here, which
# does not resume anything; resuming is done via `ckpt_path` in `fit()`.
checkpoint_callback = ModelCheckpoint(
    dirpath='./Chxxxxxxch',
    filename='{epoch}-{step}',     # e.g. "epoch=69-step=1470.ckpt"
    monitor='Val ACC',             # metric name must match what the module logs
    save_top_k=10,
    mode='max')                    # higher accuracy is better
gpus = 1
trainer = pl.Trainer(
    accelerator='gpu',             # `devices` alone does not select the GPU
    devices=gpus,
    logger=TensorBoardLogger(save_dir='./Cxxxis/logs_flair_loc3'),
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],  # Trainer expects a list of callbacks
    max_epochs=30)
# `ckpt_path` restores model weights, optimizer state, and the epoch counter,
# so training truly continues from epoch 69 instead of starting over.
trainer.fit(model, train_loader, val_loader,
            ckpt_path='./Chxxxxxxch=69-step=1470.ckpt')
测试（evaluation / inference）
# Evaluation: load the trained checkpoint and collect per-sample sigmoid
# scores over the validation set.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# `flair` is the LightningModule class; load_from_checkpoint rebuilds the
# module and restores its weights from the saved .ckpt file.
model = flair.load_from_checkpoint('./is/logs_flair2/lightning_logs/version_4/checkpoints/epoch=9-step=720.ckpt')
model.eval()                 # disable dropout / switch batchnorm to eval stats
model.to(device)
preds = []
labels = []
with torch.no_grad():        # inference only — no autograd graph needed
    for data, label in tqdm(val_dataset):
        # iterate the raw dataset one sample at a time, so add a batch
        # dimension of 1 before the forward pass
        data = data.to(device).float().unsqueeze(0)
        # assumes model(data) returns logits; [0] strips the batch dim —
        # TODO confirm output shape against the module's forward()
        pred = torch.sigmoid(model(data)[0].cpu())
        preds.append(pred)
        labels.append(label)
# torch.stack is the correct way to combine a list of tensors;
# torch.tensor(list_of_tensors) raises for multi-element tensors.
preds = torch.stack(preds)
labels = torch.tensor(labels).int()
对模型最后两层稍作删减后，仍可加载 checkpoint 中对应的参数（需在 load_from_checkpoint / load_state_dict 时传入 strict=False，以忽略被删减层缺失的键）。