PyTorch 模型保存与加载 (速查版)

PyTorch模型保存与加载速查


1. 推理用: 保存 & 加载权重 (最常见)

import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# 保存权重
torch.save(model.state_dict(), "model.pt")

# 加载权重 (推理/评估)
model2 = nn.Linear(10, 2)
state = torch.load("model.pt", map_location="cpu")
model2.load_state_dict(state)
model2.eval()   # 推理时别忘了

2. 继续训练用: 保存 & 加载完整状态

# ===== 保存 =====
torch.save({
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}, "ckpt.pt")

# ===== 加载 =====
ckpt = torch.load("ckpt.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
start_epoch = ckpt["epoch"] + 1
model.train()   # 继续训练前别忘了

3. 微调用: 部分加载 (分类头不同等情况)

state = torch.load("pretrain.pt", map_location="cpu")
# 只加载匹配的层, 其余保持初始化
missing, unexpected = model.load_state_dict(state, strict=False)
print("未加载:", missing) 	 	# 模型需要,但 checkpoint 里没有的
print("多余:", unexpected)  	# checkpoint 里有,但模型不需要的

速记

  • 推理: 保存/加载 model.state_dict()
  • 继续训练: 把 optimizer/scheduler 一并保存
  • 微调: strict=False 部分加载
  • 安全: map_location="cpu" 加载更通用
  • 模式: 推理用 model.eval(),训练用 model.train()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值