模型的保存和加载(pickle)

本文详细介绍了如何使用pickle模块保存和加载机器学习模型,避免重复训练,提高效率。通过具体案例展示了线性回归模型的保存过程及加载后进行预测的步骤。

模型的保存和加载

前言:

模型训练是一个耗时的过程,一个优秀的机器学习是非常宝贵的。可以模型保存到磁盘中,也可以在需要使用的时候从磁盘中重新加载模型即可。不需要重新训练。

模型保存和加载相关API:

import pickle
pickle.dump(内存对象, 磁盘文件) # 保存模型
model = pickle.load(磁盘文件)  # 加载模型

案例:把训练好的模型保存到磁盘中。

注:这里的训练数据就是一些数据而已下面是txt文档里面的截图,这里的内容手动放入到数组当中也可以
在这里插入图片描述

import numpy as np
import sklearn.linear_model as lm
import pickle
# 采集数据(这里我是从文件中读取的数据,也可以手动输入一组数字)
x, y = np.loadtxt('../ml_data/single.txt', delimiter=',', usecols=(0,1), unpack=True)
# 讲x改变为n行一列
x = x.reshape(-1, 1)
# 创建模型
model = lm.LinearRegression()
# 训练模型
model.fit(x,y)
with open('linear.pkl','wb') as f:
    pickle.dump(model,f)
    print('dump sucess')

案例: 加载训练好的模型进行线性回归线的绘制

import numpy as np
import pickle
import matplotlib.pyplot as mp
# 采集数据
x, y = np.loadtxt('../ml_data/single.txt', delimiter=',', usecols=(0,1), unpack=True)
x = x.reshape(-1, 1)
with open('linear.pkl','rb') as f:
    model = pickle.load(f)
# 进行预测
pred_y = model.predict(x)
mp.figure('Linear Regression', facecolor='lightgray')
mp.title('Linear Regression', fontsize=20)
mp.xlabel('x', fontsize=14)
mp.ylabel('y', fontsize=14)
mp.tick_params(labelsize=10)
mp.grid(linestyle=':')
mp.scatter(x, y, c='dodgerblue', alpha=0.75, s=60, label='Sample')
mp.plot(x, pred_y, c='orangered', label='Regression')
mp.legend()
mp.show()

在这里插入图片描述

<think>我们之前讨论了PyTorch的模型保存加载,现在用户询问的是JAX框架下的模型保存加载。根据引用内容,用户提供的引用中提到了JAX以及相关工具(如Paxml),但并没有直接给出JAX保存模型的示例。因此,我们需要根据JAX的常见实践来回答。 在JAX中,模型参数通常存储在`jax.Array`或`jax.numpy.Array`中,并且整个模型的状态(包括参数、优化器状态等)通常使用一个字典或类似结构来保存。JAX本身不提供像PyTorch那样的内置模型保存加载函数,但我们可以使用`flax`(一个流行的JAX神经网络库)或`jax.tree_util`结合序列化库(如`pickle`)来实现。 以下是一个基于Flax的模型保存加载的指南(因为Flax是JAX生态中最常用的神经网络库之一): ### 一、使用Flax保存加载模型 #### 1. 保存模型 在Flax中,我们可以使用`flax.serialization.to_bytes`将模型状态(包括参数)转换为字节,然后保存到文件。也可以使用`flax.serialization.to_state_dict`获取状态字典,再使用其他方式(如`pickle`)保存。 ```python import flax from flax.training import checkpoints # 假设我们有一个训练好的模型`model`它的参数`params` # 创建TrainState(包含参数优化器状态等) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=optimizer ) # 保存检查点(包括参数优化器状态) checkpoints.save_checkpoint( ckpt_dir='./checkpoints', target=state, step=step, # 当前训练步数 prefix='my_model_', keep=3 # 保留最近的3个检查点 ) ``` #### 2. 加载模型 加载时,我们使用`checkpoints.restore_checkpoint`来恢复检查点。 ```python # 首先创建一个与原始状态结构相同的空状态(用于加载) state = train_state.TrainState.create( apply_fn=model.apply, params=initial_params, # 初始参数(结构相同) tx=optimizer ) # 加载检查点 state = checkpoints.restore_checkpoint( ckpt_dir='./checkpoints', target=state, prefix='my_model_' ) ``` ### 二、使用纯JAX(无Flax)保存加载 如果不使用Flax,我们可以直接使用JAX的树结构序列化工具。 #### 1. 保存参数 ```python import jax import pickle # 假设params是一个包含模型参数的字典(或任意树结构) with open('params.pkl', 'wb') as f: pickle.dump(params, f) ``` #### 2. 加载参数 ```python with open('params.pkl', 'rb') as f: params = pickle.load(f) ``` 注意:这种方法简单,但可能在不同版本的JAX或Python之间不兼容。另外,对于大型参数,使用`pickle`可能不是最高效的方式。 ### 三、使用`orbax`(推荐) Google推荐使用`orbax`库(属于`jax`生态系统的一部分)进行更高效的检查点保存加载。特别是在分布式训练中,`orbax`可以处理并行保存加载。 ```python from orbax import checkpoint # 保存 checkpointer = checkpoint.PyTreeCheckpointer() checkpointer.save('./checkpoints/orbax', state) # 加载 restored = checkpointer.restore('./checkpoints/orbax') ``` ### 四、关键注意事项 1. **设备放置**:加载的参数可能被绑定到特定的设备(如GPU)。在加载后,可以使用`jax.device_put`将参数移动到当前设备。 2. **版本兼容性**:与PyTorch类似,不同版本的JAX/Flax可能不兼容。建议记录版本信息。 3. **模型结构**:加载参数时,必须有一个与保存时相同结构的模型(或状态)对象。 ### 五、完整示例(使用Flax) ```python import flax from flax import linen as nn from flax.training import train_state import optax import jax import jax.numpy as jnp # 定义模型 class MLP(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(128)(x) x = nn.relu(x) x = nn.Dense(10)(x) return x # 初始化模型 model = MLP() key = jax.random.PRNGKey(0) dummy = jnp.ones((1, 784)) params = model.init(key, dummy) # 创建TrainState tx = optax.adam(1e-3) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx ) # 保存 checkpoints.save_checkpoint( ckpt_dir='./ckpt', target=state, step=0, prefix='model_', keep=3 ) # 加载:先创建一个空状态(结构相同) state_restored = train_state.TrainState.create( apply_fn=model.apply, params=model.init(key, dummy), # 初始参数 tx=tx ) state_restored = checkpoints.restore_checkpoint( ckpt_dir='./ckpt', target=state_restored, prefix='model_' ) ``` ### 六、数学原理 在JAX中,模型参数被组织为树状结构(pytree),其中每个叶子节点是一个`jax.Array`。保存过程实质上是将这个树结构序列化为字节流。在加载时,我们将字节流反序列化回相同的树结构。前向传播计算可以表示为: $$ \mathbf{y} = f(\mathbf{x}; \theta) $$ 其中 $\theta$ 是模型参数,$f$ 是模型的前向传播函数。 --- §§相关问题§§ 1. JAX中如何保存加载分布式训练的模型? 2. 使用Flax时,如何保存加载额外的自定义状态(如批归一化统计量)? 3. 在JAX中,如何将模型参数从GPU移动到CPU进行保存? 4. 如何在不重新初始化模型的情况下加载部分参数?</think>### JAX 模型保存加载指南 JAX 使用 `flax` 库(官方神经网络库) `orbax`(高效检查点库)实现模型保存加载。以下是核心方法: --- #### 一、保存模型参数(推荐) ```python import flax from flax.training import orbax_utils import orbax.checkpoint as ocp # 保存模型参数 checkpointer = ocp.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(params) # params 为模型参数字典 checkpointer.save('model_dir', params, save_args=save_args) ``` - **优点**: - 文件小(仅保存参数,不包含模型结构) - 兼容不同硬件(TPU/GPU/CPU) - 支持分布式训练检查点[^3] --- #### 二、加载模型参数 ```python # 加载到相同结构的模型 checkpointer = ocp.PyTreeCheckpointer() params = checkpointer.restore('model_dir') ``` - **设备映射**: ```python # 强制加载到 CPU params = checkpointer.restore('model_dir', device=jax.devices('cpu')[0]) ``` --- #### 三、保存完整训练状态(含优化器) ```python state = { 'params': params, 'optimizer': optimizer_state, 'epoch': 15 } checkpointer = ocp.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(state) checkpointer.save('full_state', state, save_args=save_args) ``` --- #### 四、恢复训练状态 ```python # 定义空状态结构模板 state_template = { 'params': init_params, # 初始化的参数字典 'optimizer': init_optimizer_state, 'epoch': 0 } # 加载到模板 checkpointer = ocp.PyTreeCheckpointer() full_state = checkpointer.restore('full_state', item=state_template) ``` --- #### 五、关键注意事项 1. **模型结构分离** JAX 不保存模型结构,需单独维护模型定义代码[^2]: ```python class MLP(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(128)(x) return nn.Dense(10)(x) ``` 2. **分布式训练保存** 使用 `jax.distributed.initialize()` 确保多设备参数正确聚合: ```python from jax.sharding import Mesh, PartitionSpec sharding = jax.sharding.NamedSharding(Mesh(...), PartitionSpec(...)) checkpointer.save(..., shardings={'params': sharding}) ``` 3. **数学原理** 参数保存为树形结构(pytree),每个叶子节点是 $W \in \mathbb{R}^{m \times n}$ 权重矩阵。前向传播计算: $$ \mathbf{y} = \phi(\mathbf{x}W + \mathbf{b}) $$ 其中 $\phi$ 为激活函数(如 ReLU)。 --- #### 六、完整示例 ```python import jax import flax.linen as nn from flax.training import orbax_utils import orbax.checkpoint as ocp # 1. 定义模型 class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(32, (3,3))(x) x = nn.relu(x) return nn.Dense(10)(x) # 2. 初始化 model = CNN() params = model.init(jax.random.PRNGKey(0), jnp.ones([1,28,28,1])) # 3. 保存 checkpointer = ocp.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(params) checkpointer.save('cnn_params', params, save_args=save_args) # 4. 加载 new_params = checkpointer.restore('cnn_params') ``` > **最佳实践**:生产环境使用 `orbax` 保存参数,研究场景可结合 `flax.serialization` 快速调试[^3][^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值