数据访问
1.加载真实数据集
2.numpy读取文件后读取对应的样本数据
3.抽取对应的建模数据集(训练数据集和测试数据集)
4.抽取自变量矩阵X
5.抽取因变量矩阵Y
6.基于(X-Y)训练集构建决策树模型
7.生成模型文件
8.根据生成的模型文件,输入测试数据样本,调用模型预测测试样本的结果
示例:利用决策树模型预测
加载,训练,保存模型,评估
import time
import numpy as np
from sklearn import metrics
from sklearn import tree
from sklearn.externals import joblib
raw_data='dcdata.data'
#load the CSV file as a numpy matrix
dataset=np.loadtxt(raw_data,delimiter=',')
#separate the data from the target attributes
X=dataset[:,0:7]
y=dataset[:,8]
#训练集合
X_train=dataset[0:500,0:8]
y_train=dataset[0:500,8]
#测试集合
X_test=dataset[500:,0:8]
y_test=dataset[500:,8]
#利用训练数据集建立决策树模型
print('\n调用scikit的tree.DecisionTreeClassifier()')
model=tree.DecisionTreeClassifier(min_samples_leaf=2)
start_time=time.time()
model.fit(X_train,y_train)
print('training took %fs!' % (time.time()-start_time))
#把训练好的决策树模型数据文件保存在磁盘中
joblib.dump(value=model,filename='Decisiontree.model')
#训练样本的准确性评估(准确率和召回率)
expected=y_test
predicted=model.predict(X_test)
print(metrics.confusion_matrix(expected,predicted))
print(metrics.classification_report(expected,predicted))
结果:
调用scikit的tree.DecisionTreeClassifier()
training took 0.015627s!
[[151 31]
[ 35 51]]
precision recall f1-score support
0.0 0.81 0.83 0.82 182
1.0 0.62 0.59 0.61 86
accuracy 0.75 268
macro avg 0.72 0.71 0.71 268
weighted avg 0.75 0.75 0.75 268
还原模型,预测
import numpy as np
from sklearn.externals import joblib
dataset=np.loadtxt(fname='dcdata.data',delimiter=',')
x_predict=dataset[500:510,0:8]
y_real=dataset[500:510,8]
gnbmodel=joblib.load(filename='Decisiontree.model')
y_predict=gnbmodel.predict(x_predict)
print('预测值')
print(y_predict)
print('真实值')
print(y_real)
结果:
预测值
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
真实值
[0. 0. 1. 0. 0. 0. 1. 0. 0. 0.]
安装mysql
改密码ALTER user ‘root’@‘localhost’ IDENTIFIED BY ‘新密码’
启动mysql -u root -p
引用MySQL-python库包
import MySQLdb
显示用户权限show grants for ‘hadoop’@‘localhost’;
//赋予权限grant select,insert,update,delete on pythondb.* to ‘hadoop’@‘localhost’;
需要DROP权限
grant ALL on pythondb.* to ‘hadoop’@‘localhost’;
访问mysql
import MySQLdb
#打开数据库连接
db=MySQLdb.connect('127.0.0.1','hadoop','hadoop','pythondb',charset='utf8') //127.0.0.1:localhost
#使用cursor()方法获取操作游标
cursor=db.cursor()
#使用execute方法执行SQL语句
cursor.execute('SELECT VERSION()')
#使用fetchone()方法获取一条数据
data=cursor.fetchone()
print('Database version:%s' % data)
#关闭数据库连接
db.close()
结果:
Database version:8.0.16
创建数据库表
#如果数据库连接存在,可以使用execute()方法来为数据库创建表
import MySQLdb
#打开数据库连接
db=MySQLdb.connect('127.0.0.1','hadoop','hadoop','pythondb',charset='utf8')
#使用cursor()方法获取操作游标
cursor=db.cursor()
#如果数据表已经存在使用execute()方法删除表
cursor.execute('DROP TABLE IF EXISTS EMPLOYEE')
#创建数据表SQL语句
sql='''CREATE TABLE EMPLOYEE(
FIRST_NAME CHAR(20) NOT NULL,
LAST_NAME CHAR(20),
AGE INT,
SEX CHAR(1),
INCOME FLOAT)'''
cursor.execute(sql)
#关闭数据库连接
db.close()
结果:
(去mysql那边查看)
14 19:00