Saving and Loading Model Parameters in PyTorch

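This article shows how to save and load model parameters in PyTorch. It covers saving and loading model weights through the state_dict, as well as saving and loading the entire model. The state_dict approach is recommended because it lets the model be loaded in a different environment. The article also discusses what to watch for when saving multiple models or reusing pretrained parameters, and shows how to handle keys that do not match at load time.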


import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim

Saving and loading models

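1. torch.save: saves a serialized object to disk. It uses Python's pickle module for serialization. Models, tensors, and dictionaries can all be saved with this function.
2. torch.load: uses pickle's unpickling facilities to deserialize a saved object file into memory.
3. torch.nn.Module.load_state_dict: loads a model's parameter dictionary from a deserialized state_dict.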
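As a minimal sketch of the torch.save / torch.load round trip described above, the following saves a plain dictionary to a temporary file and reads it back. The dictionary contents and the file name payload.pt are illustrative assumptions, not something from the original article.

```python
import os
import tempfile

import torch

# A plain dictionary holding a tensor and an integer (illustrative names).
payload = {
    "weights": torch.arange(6, dtype=torch.float32).reshape(2, 3),
    "step": 42,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "payload.pt")
    torch.save(payload, path)    # serialize the dictionary to disk
    restored = torch.load(path)  # deserialize it back into memory

print(type(restored))
print(torch.equal(payload["weights"], restored["weights"]))
print(restored["step"])
```

The restored object is an ordinary Python dictionary, which is why the same pair of functions works for state_dicts and for checkpoints that bundle several items.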

What is a state_dict?

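In PyTorch, a model's learnable parameters (weights and biases) are held in the model's parameters, which can be accessed through model.parameters().

Put simply, a state_dict is a Python dictionary object that maps each layer to its parameter tensor: {first_layer: parameters}

Note: only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batch normalization's running_mean) have entries in a model's state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains the optimizer's state and the hyperparameters it uses.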
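To illustrate which layers contribute entries, here is a small sketch that classifies every state_dict key of a toy network as a parameter or a buffer. The network layout is an assumption chosen only so that it includes a batch normalization layer and two layers without any state.

```python
import torch.nn as nn

# Toy network: conv and batch norm carry state, ReLU and max pooling do not.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

param_names = {name for name, _ in net.named_parameters()}
buffer_names = {name for name, _ in net.named_buffers()}
state_keys = list(net.state_dict().keys())

# Label each state_dict entry as a learnable parameter or a registered buffer.
for key in state_keys:
    kind = "parameter" if key in param_names else "buffer"
    print(f"{key:25s} {kind}")
```

The ReLU and MaxPool2d layers (indices 2 and 3) do not appear at all, while the batch normalization layer contributes both learnable parameters and buffers such as running_mean.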


# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

# initialize model
model = TheModelClass()

# initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# print model's state_dict
print("model's state_dict")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())


# print optimizer's state_dict
print("optimizer's state_dict")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

### ---output---
'''
model's state_dict
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
optimizer's state_dict
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2471740491528, 2471740491600, 2471740491672, 2471740491744, 2471740491816, 2471740491888, 2471740491960, 2471740492032, 2471740492104, 2471740492176]}]
'''


Saving & Loading model for Inference
1. Save/Load state_dict (recommended)

Save:
torch.save(model.state_dict(), PATH)
Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

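By convention, PyTorch model files are saved with a .pt or .pth extension.
Note: call model.eval() before running inference to put dropout and batch normalization layers into evaluation mode.
Otherwise inference results will be inconsistent.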
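The effect of model.eval() can be seen in a small sketch with a dropout layer. The two-layer network and the all-ones input are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
# In training mode a fresh dropout mask is sampled on every forward pass,
# so repeated calls on the same input will usually differ.
train_a, train_b = net(x), net(x)

net.eval()
with torch.no_grad():
    # In evaluation mode dropout is a no-op, so the output is deterministic.
    eval_a, eval_b = net(x), net(x)

print("train outputs equal:", torch.equal(train_a, train_b))
print("eval outputs equal: ", torch.equal(eval_a, eval_b))
```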

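Note: load_state_dict() takes a dictionary object, not a path to a saved file, so the saved state_dict must first be deserialized with torch.load().
For example, the following is incorrect: model.load_state_dict(PATH)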
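The sketch below contrasts the incorrect and correct calls. It assumes a small nn.Linear model and a temporary file. The exact exception type raised for a path string has varied across PyTorch versions, so both TypeError and AttributeError are caught here as an assumption.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(3, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "linear.pt")
    torch.save(model.state_dict(), path)

    fresh = nn.Linear(3, 2)
    try:
        fresh.load_state_dict(path)  # wrong: a path string, not a dictionary
        path_rejected = False
    except (TypeError, AttributeError) as err:
        path_rejected = True
        print("rejected:", type(err).__name__)

    # Correct: deserialize the file first, then pass the resulting dictionary.
    fresh.load_state_dict(torch.load(path))

print(torch.equal(fresh.weight, model.weight))
```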

2. Save/Load Entire Model

Save:
torch.save(model, PATH)

Load:

# model class must be defined somewhere
model = torch.load(PATH)
model.eval()

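This approach saves the entire model using Python's pickle module.

The drawback is that pickle does not save the model class itself. Instead, it saves the path to the file that contains the model class. For example, if the model class is defined in model.py, pickle records the file path …/model.py. This makes the saved model very error-prone to use in other projects.

As before, the saved file conventionally uses a .pt or .pth extension, and model.eval() must be called before inference.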

Saving & Loading a general checkpoint for Inference and/or Resuming Training

Save:

torch.save({
     'epoch': epoch,
     'model_state_dict': model.state_dict(),
     'optimizer_state_dict': optimizer.state_dict(),
     'loss': loss,
     ...
 }, PATH)

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
...

model.eval()

# -or-
model.train()

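To save multiple components, organize them in a dictionary and serialize that dictionary with torch.save(). The file extension commonly used for such checkpoints is .tar.

To load the checkpoint, first initialize the model and the optimizer, then load the dictionary with torch.load(). Each item can then be retrieved using the key it was stored under.

Call model.eval() before running inference.
Call model.train() when resuming training.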
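The save and resume steps above can be sketched end to end as follows. The tiny nn.Linear model, the random data, the single training step, and the file name checkpoint.tar are all assumptions made to keep the example self-contained.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

# One training step so the optimizer holds momentum state worth saving.
loss = nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.tar")
    torch.save({
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss.item(),
    }, path)

    # Resume: build fresh objects first, then restore their state.
    resumed_model = nn.Linear(4, 1)
    resumed_optimizer = optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
    checkpoint = torch.load(path)
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

resumed_model.train()  # training mode, because training continues from here
print("resuming from epoch", start_epoch)
```

Saving the optimizer's state_dict alongside the model matters when the optimizer keeps internal state such as momentum buffers. Without it, resumed training would start with that state reset.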

Saving multiple models in one file

Save:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

Load:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

Warm-starting a model using parameters from a different model

Save:

torch.save(modelA.state_dict(), PATH)

Load:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

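With strict=False, load_state_dict() ignores keys that do not match.

To initialize a layer with the parameters of some other layer when the keys do not match, rename the parameter keys in the state_dict being loaded so that they match the layer names in the target model.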
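Both ideas can be combined in a short sketch: rename the keys that should transfer, then load with strict=False so the remaining mismatches are skipped. ModelA, ModelB, and their layer names are hypothetical classes invented for this example.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class ModelA(nn.Module):  # hypothetical source model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head_a = nn.Linear(8, 2)


class ModelB(nn.Module):  # hypothetical target model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)  # same shape as ModelA.encoder, different name
        self.head_b = nn.Linear(8, 3)


model_a, model_b = ModelA(), ModelB()
source_state = model_a.state_dict()

# Rename "encoder.*" keys to "backbone.*" so they match ModelB's layer names.
renamed_state = OrderedDict(
    (key.replace("encoder.", "backbone.", 1), value)
    for key, value in source_state.items()
)

# strict=False skips keys that exist on only one side and reports them.
result = model_b.load_state_dict(renamed_state, strict=False)
print("missing keys:   ", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
print(torch.equal(model_b.backbone.weight, model_a.encoder.weight))
```

Checking the returned missing_keys and unexpected_keys is a cheap way to confirm that only the intended layers were skipped. Note that strict=False relaxes key matching only: a key that matches by name but has a different tensor shape still raises an error.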
