11. 机器学习——sklearn 中模型的保存和加载

本文介绍了使用sklearn库中的joblib模块来保存和加载模型的过程。通过预测波士顿房价的数据集,演示了如何训练一个岭回归模型,然后将其保存为.pkl文件,最后再从该文件中加载模型并进行预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

sklearn模型的保存和加载

from sklearn.externals import joblib

在这里插入图片描述
示例;

保存代码

from sklearn.datasets import load_boston
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.externals import joblib

#线性回归预测房价
#获取数据
lb = load_boston()

#分割数据集到训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(lb.data, lb.target, test_size=0.25)

print(y_test.shape) #一维的
#进行标准化处理(特征值和目标值都需要标准化处理)
std_x = StandardScaler()

x_train = std_x.fit_transform(x_train)
x_test = std_x.transform(x_test)

#目标值
std_y = StandardScaler()  

y_train = std_y.fit_transform(y_train.reshape(-1,1))   #sklearn 0.19之后必须要求穿进去的数组是二维的

y_test = std_y.transform(y_test.reshape(-1,1))

#岭回归
rd = Ridge()
rd.fit(x_train,y_train)
print(rd.coef_)
#预测测试集房子的价格
y_predict = rd.predict(x_test)
y_predict = std_y.inverse_transform(y_predict) #反标准化
print('房子的价格',y_predict)

print("岭回归的均方误差:", mean_squared_error(std_y.inverse_transform(y_test), y_predict))

#  保存训练好的模型
joblib.dump(rd, "./test.pkl")

加载模型代码:

from sklearn.datasets import load_boston
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.externals import joblib

#线性回归预测房价
#获取数据
lb = load_boston()

#分割数据集到训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(lb.data, lb.target, test_size=0.25)

print(y_test.shape) #一维的
#进行标准化处理(特征值和目标值都需要标准化处理)
std_x = StandardScaler()

x_train = std_x.fit_transform(x_train)
x_test = std_x.transform(x_test)

#目标值
std_y = StandardScaler()  

y_train = std_y.fit_transform(y_train.reshape(-1,1))   #sklearn 0.19之后必须要求穿进去的数组是二维的

y_test = std_y.transform(y_test.reshape(-1,1))

# 预测房价结果
model = joblib.load("./test.pkl")

y_predict = std_y.inverse_transform(model.predict(x_test))

print("保存的模型预测的结果:", y_predict)

结果:

(127,)
保存的模型预测的结果: [[16.23388778]
 [34.46783927]
 [17.39207575]
 [20.34968571]
 [24.9583924 ]
 [15.61849472]
 [35.87311213]
 [25.18871655]
 [20.8395896 ]
 [31.79856304]
 [26.8584293 ]
 [29.58731165]
 [31.52979379]
 [26.98212629]
 [23.76453604]
 [15.39954177]
 [21.12670518]
 [37.27918572]
 [17.40884427]
 [26.38441272]
 [30.99260298]
 [11.90111108]
 [33.466762  ]
 [18.64718159]
 [16.3283264 ]
 [29.90053098]
 [28.12101341]
 [35.4850507 ]
 [32.16509515]
 [26.20302906]
 [33.60972295]
 [17.35841016]
 [24.81281014]
 [19.69721431]
 [19.89760743]
 [14.5323029 ]
 [19.40147275]
 [18.18225318]
 [26.51806171]
 [18.09390639]
 [28.72281652]
 [21.61058982]
 [31.25590894]
 [23.58129881]
 [24.62088413]
 [20.96266656]
 [36.04464208]
 [28.14231578]
 [21.57702682]
 [17.35323118]
 [ 0.1633365 ]
 [26.55874935]
 [11.26994628]
 [21.05984545]
 [14.58014368]
 [25.96573008]
 [25.43953375]
 [24.02403274]
 [17.25132785]
 [25.22760346]
 [30.15305598]
 [ 6.03964511]
 [ 0.96805392]
 [25.62738799]
 [32.96577784]
 [17.99656637]
 [21.45162342]
 [ 6.26979511]
 [14.36472757]
 [12.55306632]
 [ 8.40092134]
 [35.36277669]
 [20.05142827]
 [24.03724039]
 [12.24981433]
 [36.3640668 ]
 [23.54912163]
 [30.46909114]
 [42.36433528]
 [22.96102391]
 [31.74392036]
 [19.33368472]
 [25.14295165]
 [31.10898614]
 [34.99304306]
 [26.69857033]
 [19.15719529]
 [18.07479011]
 [14.31562316]
 [16.65082112]
 [18.80402695]
 [16.61279377]
 [20.5808669 ]
 [32.06112169]
 [24.95389865]
 [23.23594216]
 [30.60768144]
 [14.66489868]
 [23.94118664]
 [24.91552985]
 [22.78715915]
 [21.37019274]
 [23.49912324]
 [19.07787155]
 [24.93010068]
 [37.19696341]
 [32.38019967]
 [18.81492816]
 [35.75030244]
 [ 8.53630579]
 [28.3872857 ]
 [19.58905858]
 [20.99951863]
 [24.624171  ]
 [23.65550723]
 [23.16321702]
 [20.46835972]
 [24.1533658 ]
 [27.53038139]
 [11.70073628]
 [34.24201791]
 [29.18474934]
 [22.19946866]
 [20.21059142]
 [20.1780728 ]
 [32.67601873]
 [15.98186465]]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值