本关必读
当数据量很大的时候,训练一个模型需要消耗很大的时间成本,每次都重新训练模型预测是非常冗余且没有必要的,我们可以将训练模型存储下来,每当要预测新数据的时候只需加载该模型。
训练模型的持久化需要调用python的内建模块pickle,pickle可以用来将python对象转化为字节流存储至磁盘,也可以逆向操作将磁盘上的字节流恢复为python对象。pickle的常用函数包括:
#将对象obj写到文件file中
pickle.dump(obj, file[, protocol])
#从文件file读取数据流并将其重建返回原始对象
pickle.load(file)
关于pickle的详细使用可以参考官方文档:https://docs.python.org/2/library/pickle.html
python 文件操作:
open(路径+文件名,读写模式)
读写模式:
r:只读;
r+:读写;
w:新建(会覆盖原有文件);
a:追加;
b:二进制文件;
示例:
#打开本地file文件,并开启写模式
fw=open('file', 'wb')
#向file文件中写入‘hello,world’
fw.write('hello,world')
#打开file文件并读取其中内容
fw=open('file', 'rb')
fw.read()
‘wb’,'rb'分别表示以二进制流的方式进行写入和读取。
####本关任务
本关在上一关的基础上,希望将分类模型存储下来,当需要预测数据时加载该模型返回预测值。
本关需编程实现step3/dumpClassificationModel.py 的dumpModel()函数存储分类模型,并且实现loadModel()函数来加载存储模型对预测数据分类,分类模型的实现在createModel()函数中:
# 导入数据集,分类器相关包
from sklearn import datasets, svm, metrics
import pickle
# 导入digits数据集
digits = datasets.load_digits()
n_samples = len(digits.data)
data = digits.data
# 使用前一半的数据集作为训练数据,后一半数据集作为测试数据
train_data,train_target = data[:n_samples // 2],digits.target[:n_samples // 2]
test_data,test_target = data[n_samples // 2:],digits.target[n_samples // 2:]
def createModel():
classifier = svm.SVC()
classifier.fit(train_data,train_target)
return classifier
local_file = 'dumpfile'
def dumpModel():
'''
存储分类模型
'''
clf = createModel()
# 请在此添加实现代码 #
#********** Begin *********#
#********** End **********#
def loadModel():
'''
加载模型,并使用模型对测试数据进行预测,返回预测值
返回值:
predicted - 模型预测值
'''
predicted = None
# 请在此添加实现代码 #
#********** Begin *********#
#********** End **********#
return predicted
实现提示:
local_file对应即将存储在平台的文件的名称。dumpModel()函数中首先需要打开local_file文件,并开启写入模式,再使用pickle模块将模型存储下来,loadModel()函数中也需要先打开local_file文件,开启读取模式,再使用pickle将local_file文件中存储的模型load至模型变量中,再使用该模型预测。
测试说明
本关的测试文件是step3/testDigitsClassification.py该代码负责对你的实现代码进行测试,注意该测试不能被修改,该测试代码具体如下:
import dumpClassificationModel
import os
dumpClassificationModel.dumpModel()
if os.path.exists('dumpfile'):
print("dump success")
else:
print("dump fail")
predicted = dumpClassificationModel.loadModel()
print(predicted[:10])
测试函数将直接调用step3/dumpClassificationModel.py 的dumpModel()存储模型,并检测平台是否存在该文件,然后调用loadModel()函数得到模型实际预测值,平台通过比较预期输出与实际输出的前10个值来判断成功加载模型并预测。
# 导入数据集,分类器相关包
from sklearn import datasets, svm, metrics
import pickle
# 导入digits数据集
digits = datasets.load_digits()
n_samples = len(digits.data)
data = digits.data
# 使用前一半的数据集作为训练数据,后一半数据集作为测试数据
train_data,train_target = data[:n_samples // 2],digits.target[:n_samples // 2]
test_data,test_target = data[n_samples // 2:],digits.target[n_samples // 2:]
def createModel():
classifier = svm.SVC()
classifier.fit(train_data,train_target)
return classifier
local_file = 'dumpfile'
def dumpModel():
'''
存储分类模型
'''
clf = createModel()
# 请在此处补全模型存储语句 #
#********** Begin *********#
f_model = open(local_file, 'wb')
pickle.dump(clf, f_model)
#********** End **********#
def loadModel():
'''
加载模型,并使用模型对测试数据进行预测,返回预测值
返回值:
predicted - 模型预测值
'''
predicted = None
# 请在此处补全模型加载语句,并对预测数据分类返回预测值#
#********** Begin *********#
fw = open(local_file, 'rb')
classifier = pickle.loads(fw.read())
predicted = classifier.predict(test_data)
#********** End **********#
return predicted