# -*- coding: utf-8 -*-
"""Rank the columns of the ``telco`` table by importance for predicting churn.

The script loads ``a_hisdata.telco`` from a local MySQL server, fits an
``ExtraTreesClassifier`` with ``churn`` as the target, prints the per-column
feature importances, and keeps only the columns whose importance exceeds
``IMPORTANCE_THRESHOLD``.
"""
import datetime
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymysql
import seaborn as sns  # can be used to draw heatmaps
import sklearn
from finta import TA  # financial indicator library (moving-average helpers)
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.feature_selection import SelectFromModel
from sqlalchemy import create_engine

# Columns with an importance at or below this value are discarded.
IMPORTANCE_THRESHOLD = 0.05

# NOTE(review): credentials are hard-coded here and in the engine URL below;
# consider moving them to configuration or environment variables.
dbconn = pymysql.connect(
    host="127.0.0.1",
    database="A_HISDATA",
    user="root",
    password="111111",
    port=3306,
    charset='utf8',
)
# Not used in the visible part of this script; kept in case later code
# relies on it.
connt = create_engine('mysql+mysqldb://root:111111@localhost:3306/A_HISDATA?charset=utf8')

# Load the data set from the database.
sqlcmd1 = "select * from a_hisdata.telco"
try:
    df = pd.read_sql(sqlcmd1, dbconn)
finally:
    # The connection is only needed for this single query; release it even
    # when the read fails.
    dbconn.close()
print(df.columns)

# Prepare the model inputs: every column except the row index and the target.
X = df.drop(['index', 'churn'], axis=1)
y = df['churn']

print(X.shape)
# No random_state is passed, so the importances (and therefore the selected
# columns) may differ from run to run.
clf = ExtraTreesClassifier()
clf = clf.fit(X, y)
print(clf.feature_importances_)

# `n_features_` was removed in scikit-learn 1.2 in favour of `n_features_in_`;
# fall back to the old attribute only on versions that predate the new one.
n_features = getattr(clf, "n_features_in_", None)
if n_features is None:
    n_features = clf.n_features_
print(n_features)

# Alternative selection via SelectFromModel (left disabled by the author):
# model = SelectFromModel(clf, prefit=True)
# X_new = model.transform(X)
# print(X_new.shape)

# One row per input column, sorted from most to least important.
t = pd.DataFrame({'var': X.columns, 'importance': clf.feature_importances_})
t = t.sort_values(by=['importance'], ascending=False)
print(t.columns)

# Keep only the variables whose importance exceeds the threshold.
var_ok = t[t['importance'] > IMPORTANCE_THRESHOLD]['var'].values
print(var_ok)
X_ok = X[var_ok]
print(X_ok)