加载模型继续训练与测试

pytorch lightning

继续训练

checkpoint_callback = ModelCheckpoint(
    filename='./Chxxxxxxch=69-step=1470.ckpt',
    monitor='Val ACC',
    save_top_k=10,
    mode='max')

gpus = 1
trainer = pl.Trainer(devices=gpus, logger=TensorBoardLogger(save_dir='./Cxxxis/logs_flair_loc3'), log_every_n_steps=1,
                    callbacks=checkpoint_callback, max_epochs=30)

trainer.fit(model, train_loader, val_loader)

测试

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = flair.load_from_checkpoint('./is/logs_flair2/lightning_logs/version_4/checkpoints/epoch=9-step=720.ckpt')
model.eval()
model.to(device)

preds = []
labels = []

with torch.no_grad():
    for data, label in tqdm(val_dataset):
        data = data.to(device).float().unsqueeze(0)
        pred = torch.sigmoid(model(data)[0].cpu())
        preds.append(pred)
        labels.append(label)
preds = torch.tensor(preds)
labels = torch.tensor(labels).int()

对模型最后两层稍作删减后,仍可加载checkpoint使用对应的参数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值