Pytorch学习第一天: 模型保存和加载方式

本文介绍了三种PyTorch模型的保存与加载方法:仅保存模型参数的状态字典、保存整个模型、保存训练过程中的检查点;并推荐了用于显示训练进度的tqdm工具。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. 方式一: 加载和保存状态字典

只保存参数的方法在加载的时候要事先定义好跟原模型一致的模型,并在该模型的实例对象(假设名为model)上进行加载,即在使用上述加载语句前已经有定义了一个和原模型一样的Net, 并且进行了实例化 model=Net( ) .

## 保存模型的字典参数, 保存模型的后缀通常为'.pt' 或者 ‘.pth’
torch.save(model.state_dict(), PATH) ## PATH必须要有文件名称如./model.pth

## 加载模型
model = init() ## 意思为每一个不同类别的模型,需要先初始化,这里并不会加载模型,只是申请了一个网络框架;
model.load_state_dict(torch.load(PATH))
model.eval()

2. 方式二: 加载和保存整个模型

## 保存整个模型
torch.save(model, PATH)

## 加载模型
model = torch.load(PATH)
model.eval()

3. 方式三: 保存训练的某个checkpoint

## a. 先建立一个字典,保存三个参数:

state = {‘model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}

## b. 保存checkpoints
torch.save(state, PATH)

## c. 加载checkpoints去继续训练或预测

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

4. pytorch训练显示进度条的工具

tqdm https://github.com/tqdm/tqdm

在这里插入图片描述

参考学习链接ls

[1] PyTorch | 保存和加载模型-知乎
[2] Pytorch在加载的模型基础上继续训练

### YOLOv8 第三天学习教程或常见问题 #### TensorRT 部署中的挑战与解决方案 在YOLOv8的TensorRT部署过程中,可能会遇到一些特定的技术难题。为了实现高效的推理性能,在Python环境中完成YOLOv8到TensorRT的转换至关重要[^1]。 ```python import torch from ultralytics import YOLO from models.experimental import attempt_load from utils.general import non_max_suppression, scale_coords from utils.datasets import letterbox import numpy as np import tensorrt as trt import pycuda.driver as cuda import pycuda.autoinit import time def load_model(model_path): model = attempt_load(model_path, map_location=torch.device('cpu')) return model # 假设已经有一个函数可以将pytorch模型转成tensorrt引擎 engine = create_tensorrt_engine(load_model("path_to_your_model")) ``` #### 手势识别系统的构建 利用YOLOv8系列模型配合PySide6图形界面开发的手势识别应用展示了如何集成深度学习算法进入实际产品中。此项目不仅涵盖了模型的选择训练,还包括了用户交互设计等方面的内容[^2]。 #### 数据准备与预处理 对于任何机器学习任务来说,良好的数据基础都是成功的关键之一。当涉及到目标检测时,合理的划分测试集(`test.txt`)训练集(`train.txt`), 并按照指定格式准备好配置文件(.yaml),能够显著提升最终效果。例如: - 训练参数设置如下: - 使用的数据集路径:`data=datasets/mmship/data/RGB/MMship.yaml` - 模型定义文件:`model=yolov8n.yaml` - 权重初始化方式:`pretrained=ultralytics/yolov8n.pt` - 迭代次数设定为:`epochs=12` 以上这些都直接影响着模型的表现质量[^3]。 #### ONNX 格式的导出流程 除了支持原生框架内的预测外,YOLOv8还允许使用者轻松地将其内部结构转化为其他流行的标准形式如ONNX (Open Neural Network Exchange) ,从而便于跨平台移植或者加速计算过程。具体操作方法非常简单直观[^4]: ```python from ultralytics import YOLO # 加载已有的最佳权重文件 model = YOLO('best.pt') # 将其保存为.onnx结尾的新文件 model.export(format='onnx') ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱发呆de白菜头

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值