Model Training and Prediction
1. Import the required modules
2. Preprocess the data: handle missing values, normalize features, encode categorical features, etc.
3. Train a model: choose a suitable machine-learning model and fit it on the training set
4. Predict on the test data
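Step 2 (preprocessing) can be sketched with scikit-learn's Pipeline and ColumnTransformer. This is a minimal illustration on a made-up toy frame — the column names age and city are invented for the example, not taken from any dataset in this post:

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder

# Toy frame: a numeric column with a missing value and a categorical column
df = pd.DataFrame({
    "age": [25.0, np.nan, 40.0, 31.0],
    "city": ["bj", "sh", "bj", "gz"],
})

preprocess = ColumnTransformer([
    # numeric: fill missing values with the median, then standardize
    ("num", Pipeline([("impute", SimpleImputer(strategy="median")),
                      ("scale", StandardScaler())]), ["age"]),
    # categorical: one-hot encode; unseen categories become all-zero rows
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["city"]),
])

X = preprocess.fit_transform(df)
print(X.shape)  # 4 rows; 1 scaled numeric column + 3 one-hot columns
```

Bundling these steps in a transformer keeps the same imputation and encoding applied consistently to the training and test sets.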
Random Forest
Random forest parameter reference: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier
Random forests are an ensemble method: they achieve high accuracy, run efficiently on large datasets, handle inputs with high-dimensional features, and provide feature-importance estimates.
Prediction with scikit-learn's random forest classifier:
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

# Load the data
iris = datasets.load_iris()
feature = iris.feature_names
X = iris.data
y = iris.target

# Split the data, then fit a random forest
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.1, random_state=5)
clf = RandomForestClassifier(n_estimators=200)
clf.fit(train_X, train_y)
test_pred = clf.predict(test_X)

# Inspect the feature importances
print(str(feature) + '\n' + str(clf.feature_importances_))

# Evaluate the model with the F1 score
score = f1_score(test_y, test_pred, average='macro')
print("Random Forest macro-F1:", score)
score = f1_score(test_y, test_pred, average='weighted')
print("Random Forest weighted-F1:", score)
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
[0.10214829 0.02506777 0.42989154 0.4428924 ]
Random Forest macro-F1: 0.818181818181818
Random Forest weighted-F1: 0.8
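The importance array above is printed in dataset column order, which makes it hard to scan. A small sketch that sorts the features by importance (random_state=0 is an arbitrary choice here, so the exact numbers will differ from the run above):

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

iris = datasets.load_iris()
clf = RandomForestClassifier(n_estimators=200, random_state=0)
clf.fit(iris.data, iris.target)

# argsort gives ascending order; reverse it to rank most important first
order = np.argsort(clf.feature_importances_)[::-1]
for i in order:
    print(f"{iris.feature_names[i]}: {clf.feature_importances_[i]:.3f}")
```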
The LightGBM Model
LightGBM learning material: https://mp.weixin.qq.com/s/64xfT9WIgF3yEExpSxyshQ
LightGBM (Light Gradient Boosting Machine) is a framework implementing the GBDT algorithm. It supports efficient parallel and distributed training, and offers faster training, lower memory consumption, and better accuracy, which makes it well suited to massive datasets.
# LightGBM
import lightgbm as lgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

# Load the data
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3)

# Convert to LightGBM's Dataset format
train_data = lgb.Dataset(X_train, label=y_train)
validation_data = lgb.Dataset(X_test, label=y_test)

# Parameters
results = {}
params = {
    'learning_rate': 0.1,
    'lambda_l1': 0.1,
    'lambda_l2': 0.9,
    'max_depth': 1,
    'objective': 'multiclass',
    'num_class': 3,
    'verbose': -1,
}

# Train the model; record_evaluation stores per-iteration metrics in results,
# log_evaluation prints them (recent LightGBM versions use callbacks for this
# instead of the old evals_result/verbose_eval keywords)
gbm = lgb.train(params, train_data,
                valid_sets=(validation_data, train_data),
                valid_names=('validate', 'train'),
                callbacks=[lgb.record_evaluation(results),
                           lgb.log_evaluation(period=1)])

# Predict: the output is one probability per class, so take the argmax
y_pred_test = gbm.predict(X_test)
y_pred_data = gbm.predict(X_train)
y_pred_data = [list(x).index(max(x)) for x in y_pred_data]
y_pred_test = [list(x).index(max(x)) for x in y_pred_test]

# Evaluate the model
print(accuracy_score(y_test, y_pred_test))
print('train set', f1_score(y_train, y_pred_data, average='macro'))
print('validation set', f1_score(y_test, y_pred_test, average='macro'))
lgb.plot_metric(results)
plt.show()
The output is as follows:
[1] train's multi_logloss: 0.976938 validate's multi_logloss: 0.981023
[2] train's multi_logloss: 0.875592 validate's multi_logloss: 0.88226
[3] train's multi_logloss: 0.788956 validate's multi_logloss: 0.802856
...
[33] train's multi_logloss: 0.103932 validate's multi_logloss: 0.239186
[34] train's multi_logloss: 0.099617 validate's multi_logloss: 0.236112
[35] train's multi_logloss: 0.0954041 validate's multi_logloss: 0.236957
...
[99] train's multi_logloss: 0.0240774 validate's multi_logloss: 0.317843
[100] train's multi_logloss: 0.0238036 validate's multi_logloss: 0.318235
0.9333333333333333
train set 1.0
validation set 0.9332591768631814