DeepCTR-Torch 如何保存模型

这篇博客介绍了在使用DeepCTR库时遇到的模型保存和加载问题。当尝试用TensorFlow的`save_model`和`load_model`函数保存和加载DeepFM模型时,出现AttributeError。解决方法是改用PyTorch的方式,通过`torch.save`和`torch.load`来保存和加载模型的状态字典。示例代码展示了如何在DeepFM模型上实现这一过程。

官网给出的示例:

DeepCTR Documentation, Release 0.9.0
from tensorflow.python.keras.models import save_model,load_model
model = DeepFM()
save_model(model, 'DeepFM.h5')# save_model, same as before
from deepctr.layers import custom_objects
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter

实际执行的时候会报错:

AttributeError: 'DeepFM' object has no attribute 'outputs'

解决办法:

采用 torch的模型保存办法,下述示例是 Deep-torch 的示例文件 run_classification_criteo.py

from deepctr_torch.models import *

#创建model1

model = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')

#model 处理

#...

保存以及读取模型:

#model 保存

import torch

torch.save(model.state_dict(),"a.txt")

model2 = DeepFM(linear_feature_columns, dnn_feature_columns, task='binary')

model2.load_state_dict(torch.load("a.txt"))

pred_ans2 = model2.predict(test_model_input, batch_size=25
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值