Binary Search, Random Seed, and Binary Sort Tree

Topic: comparing a static table with a dynamic table, i.e. the ASL (Average Search Length) and time complexity O(f(n)) of binary search versus a binary sort tree

   

Binary Search 二分查找(折半查找)

  Applies to: sequential (sorted) tables

  Efficiency: O(log2(n)); e.g. for n = 100,000 entries a lookup needs at most about 17 comparisons

  Algorithm walkthrough: http://www.cnblogs.com/xwdreamer/archive/2012/05/07/2487246.html

  Sample program: http://www.cnblogs.com/yu-chao/archive/2012/03/23/2413686.html

Code:

/*
    Goal: practice binary search and the insert/search algorithms of a binary sort tree

    Requirements: for the input file dict.txt (one English word per line), at http://202.113.29.10/class/ds12/dict.txt

    1. Write a binary search routine, look up every word, and report the average number of comparisons;

    2. Randomly permute the words; if you don't know how, see here, or just use the pre-shuffled dictr.txt;

    3. Insert the shuffled words one by one with the binary-sort-tree insertion algorithm, then report the average search length per word
*/
#include<iostream>
#include<fstream>
#include<string>
#include<vector>
#include<ctime>
#include<cstdlib>
#include<iomanip>
using namespace std;
typedef vector<string>::size_type index;

class Static_Search_Table
{
private:
    vector<string> dict;
    index sum;
public:
    Static_Search_Table();
    ~Static_Search_Table();
    long Bi_Search(int &pos,const string key);
    string Element(int pos);
    void Performance_Analysis();
    void Traverse();
};
Static_Search_Table::Static_Search_Table()
{
    fstream read_dict("d:\\mydir\\dict.txt",ios::in);
    if(read_dict.fail())
        {cerr<<"open failed!"<<endl;exit(1);}
    string temp;
    dict.push_back("xxx dict");    // placeholder: the words occupy indices 1..sum
    while(read_dict>>temp)
        dict.push_back(temp);
    sum=dict.size()-1;
    cout<<"the whole number of words is:"<<sum<<endl;
    read_dict.close();
}
Static_Search_Table::~Static_Search_Table()
{
    cout<<"mission completed!"<<endl;
    system("pause");
}
/*
    Binary search, also called half-interval search:
        Pros: few comparisons, fast lookup, good average performance
        Cons: the table must be sorted, and insertion/deletion are expensive
    Method:
        1. Determine the interval that may hold the record
        2. Repeatedly halve it
    Environment:
        dict.txt is in dictionary order, i.e. ascending under string's comparison operators
    Stop conditions:
        1. The search succeeds, or
        2. The interval shrinks to length 0 and the search fails
*/
long Static_Search_Table::Bi_Search(int &pos,const string key)
{
    int beg,end,mid;
    long num;    // number of comparisons performed
    beg=1;end=sum;num=0;
    while(beg<=end)
        {
            mid=(beg+end)/2;
            if(dict[mid]==key)
                {pos=mid;++num;break;}
            else
                if(dict[mid]<key)
                    {beg=mid+1;++num;}    // key lies in the upper half
                else
                    {end=mid-1;++num;}    // key lies in the lower half
        }
    return num;
}
string Static_Search_Table::Element(int pos)
{
    return dict[pos];
}
/*
    Performance analysis module:
    search performance is measured by ASL (Average Search Length); the smaller the ASL, the better
*/
void Static_Search_Table::Performance_Analysis()
{
    int temp;
    long SSL=0;    // sum of search lengths over all words
    double ASL;
    for(index i=1;i<dict.size();++i)
            SSL+=Bi_Search(temp,dict[i]);
    ASL=double(SSL)/double(sum);
    cout<<"ASL:"<<ASL<<endl;
}
void Static_Search_Table::Traverse()
{
    fstream traverse("d:\\mydir\\Traverse.txt",ios::out|ios::trunc);
    if(traverse.fail())
        {cerr<<"open failed!"<<endl;exit(1);}
    for(index i=1;i<=sum;++i)    // dict[1..sum] holds the words
            traverse<<dict[i]<<endl;
    traverse.close();
}

int main()
{
    clock_t start,finish,duration;
    Static_Search_Table SST;
    start=clock();
    SST.Performance_Analysis();
    finish=clock();
    cout<<"the precise time is:"<<finish-start<<endl;
    duration=(finish-start)/CLOCKS_PER_SEC;
    cout<<"the whole time is:"<<duration<<"s"<<endl;
    system("pause");
    return 0;
}
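
Incidentally, the standard library already provides binary search over any sorted range. A minimal sketch (assuming the words sit in a sorted vector<string> as in the table above; std_bi_search is a hypothetical helper, not part of the original program):

#include<algorithm>
#include<string>
#include<vector>
using namespace std;

// Returns the index of key in the sorted vector, or -1 when absent.
// lower_bound performs the same repeated halving of the interval as Bi_Search above.
long std_bi_search(const vector<string> &dict,const string &key)
{
    vector<string>::const_iterator it=lower_bound(dict.begin(),dict.end(),key);
    if(it!=dict.end() && *it==key)
        return it-dict.begin();
    return -1;
}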

 

DisOrdered Dictionary (generating a shuffled dictionary):

  Permutation: http://en.wikipedia.org/wiki/Random_permutation

  Random number generation: http://www.cnblogs.com/longdouhzt/archive/2011/10/15/2213756.html

  Efficient generation of non-repeating random numbers: http://www.cppblog.com/sleepwom/archive/2010/01/13/105570.html

  Why shuffle at all: inserting the words in sorted order would degenerate the binary sort tree into a linked list with O(n) lookups, so a random insertion order is needed.

  How non-repeating random numbers are generated quickly: take the n positive integers 1 to n and turn them into a permutation by random swaps (the Fisher-Yates shuffle).

               Loop from the last element down to the second element; call the current position i.

               On each step pick j uniformly at random from positions 0 to i (inclusive, so an element may also stay put) and swap elements i and j.

               It can be proven mathematically that these steps yield a uniformly random permutation. (Picking j only from 0 to i-1 would instead give Sattolo's algorithm, which produces only cyclic permutations.)

#include<iostream>
#include<fstream>
#include<vector>
#include<string>
#include<ctime>
#include<cstdlib>
using namespace std;

class DisOrdered
{
private:
    vector<int> v;             // permutation of word indices
    vector<string> dict;
    int sum;
public:
    DisOrdered();
    ~DisOrdered();
    void Swap(int i,int j);
    void Replacement();
};
DisOrdered::DisOrdered()
{
    fstream read_dict("d:\\mydir\\dict.txt",ios::in);
    if(read_dict.fail())
        {cerr<<"open failed!"<<endl;exit(1);}
    string temp;
    while(read_dict>>temp)
        dict.push_back(temp);
    sum=dict.size();
    read_dict.close();
    for(int i=0;i<sum;++i)
        v.push_back(i);        // start from the identity permutation 0..sum-1
    cout<<"init successfully!"<<endl;
}
DisOrdered::~DisOrdered()
{
    cout<<"exit successfully!"<<endl;
}
void DisOrdered::Swap(int i,int j)
{
    int temp;
    temp=v[j];v[j]=v[i];v[i]=temp;
}
void DisOrdered::Replacement()
{
    int i,j;
    fstream write_dict_random("d:\\mydir\\dict_random.txt",ios::out|ios::trunc);
    if(write_dict_random.fail())
        {cerr<<"open failed!"<<endl;exit(1);}
    srand((unsigned)time(0));
    time_t beg,end,dur;
    beg=clock();
    for(i=sum-1;i>0;--i)       // Fisher-Yates: walk from the last element down to the second
        {
            j=rand()%(i+1);    // pick j uniformly from 0..i (inclusive)
            Swap(i,j);
        }
    end=clock();
    dur=end-beg;
    cout<<"duration:"<<double(dur)/double(CLOCKS_PER_SEC)<<endl;
    for(i=0;i<sum;++i)
        write_dict_random<<dict[v[i]]<<endl;    // emit the words in shuffled order
    write_dict_random.close();
}
int main()
{
    DisOrdered DOdict;
    DOdict.Replacement();
    system("pause");
    return 0;
}
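
Since C++11 the same shuffle is available directly from the standard library, and it avoids the slight modulo bias that rand()%(i+1) has whenever RAND_MAX+1 is not divisible by i+1. A minimal sketch (shuffle_words is a hypothetical helper operating on the word list loaded as above):

#include<algorithm>
#include<random>
#include<string>
#include<vector>
using namespace std;

// Shuffles the words in place; std::shuffle runs an unbiased Fisher-Yates pass.
void shuffle_words(vector<string> &dict)
{
    random_device rd;          // nondeterministic seed source
    mt19937 gen(rd());         // Mersenne Twister engine, seeded once
    shuffle(dict.begin(),dict.end(),gen);
}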

 

 

Binary Sort Tree (the structure underlying the red-black tree)

  STL containers: the set container and the multiset container.

  root: the first element inserted

  left_child: less than its parent node

  right_child: greater than its parent node

  On sets: http://apps.hi.baidu.com/share/detail/18593242

       http://www.cplusplus.com/reference/stl/set/

  Book: The C++ Standard Library: A Tutorial and Reference (Chinese edition, trans. Hou Jie / Meng Yan, Huazhong University of Science and Technology Press), 6.5 sets and multisets, pages 175-191
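
For contrast with the hand-rolled version below, here is roughly what the set-based variant would look like. A minimal sketch (the file path follows the code below; the probe word "apple" is only an illustration):

#include<fstream>
#include<iostream>
#include<set>
#include<string>
using namespace std;

int main()
{
    ifstream in("d:\\mydir\\dictr.txt");
    set<string> words;          // a red-black tree underneath
    string w;
    while(in>>w)
        words.insert(w);        // O(log2 n) per insertion, duplicates ignored
    cout<<(words.count("apple")?"found":"missing")<<endl;    // O(log2 n) lookup
    return 0;
}

Because set keeps its red-black tree balanced, lookups stay at O(log2 n) even for sorted input; the drawback, as the header comment below notes, is that set hides its tree, so the per-word comparison count (and hence the ASL) cannot be read off it directly.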

 Code:

/*
    I originally meant to build the tree with the set container here;

    the STL reference says set is implemented as a red-black tree (a self-balancing binary sort tree), which would be much simpler,

    but since the ASL has to be computed by hand, I wrote a binary sort tree from scratch
*/
#include<iostream>
#include<fstream>
#include<string>
#include<ctime>
#include<cstdlib>
#include<iomanip>
using namespace std;

typedef struct Binary_Tree_Node
{
    string data;
    Binary_Tree_Node *left_child,*right_child;
}BTN;

class Binary_Sort_Tree
{
private:
    BTN root;
    int sum;
public:
    Binary_Sort_Tree();
    ~Binary_Sort_Tree();
    long Match(const string s);
    void Performance_Analysis();
    void Insert(const string s);
    void Destroy(BTN *&p);
};
Binary_Sort_Tree::Binary_Sort_Tree()
{
    fstream read_dictr("d:\\mydir\\dictr.txt",ios::in);
    if(read_dictr.fail())
        {cerr<<"init open failed!"<<endl;exit(1);}
    string temp;
    read_dictr>>root.data;    // the first word becomes the root
    root.left_child=root.right_child=NULL;
    while(read_dictr>>temp)
        {
            BTN *x=&root;
            while(x!=NULL)    // walk down until a free child slot is found
            {
                if(temp<x->data)
                    if(x->left_child!=NULL)
                        x=x->left_child;
                    else
                        {
                            x->left_child=new BTN;
                            x=x->left_child;
                            x->data=temp;
                            x->left_child=x->right_child=NULL;
                            x=NULL;    // done: leave the inner loop
                        }
                else
                    if(x->right_child!=NULL)
                        x=x->right_child;
                    else
                        {
                            x->right_child=new BTN;
                            x=x->right_child;
                            x->data=temp;
                            x->left_child=x->right_child=NULL;
                            x=NULL;    // done: leave the inner loop
                        }
            }
        }
    read_dictr.close();
}
long Binary_Sort_Tree::Match(const string s)
{
    long num=0;
    BTN *x=&root;
    while(x!=NULL)
        {
            ++num;             // one comparison per node visited
            if(x->data==s)
                return num;    // found
            if(s<x->data)
                x=x->left_child;
            else
                x=x->right_child;
        }
    return num;                // fell off the tree: s is absent
}
void Binary_Sort_Tree::Performance_Analysis()
{
    fstream test("d:\\mydir\\dictr.txt",ios::in);
    if(test.fail())
        {cerr<<"open failed!"<<endl;exit(1);}
    long sum=0,SSL=0;
    double ASL;
    string temp;
    while(test>>temp)
        {SSL+=Match(temp);++sum;}
    ASL=double(SSL)/double(sum);
    cout<<"sum:"<<sum<<endl;
    cout<<"SSL:"<<SSL<<endl;
    cout<<"ASL:"<<ASL<<endl;
    test.close();
}
void Binary_Sort_Tree::Insert(const string s)
{
    BTN **link=NULL;    // the parent's child pointer the new node will hang off
    BTN *x=&root;
    while(x!=NULL)
        {
            if(x->data==s)
                {cerr<<"string:"<<s<<" already exists!"<<endl;return;}
            if(s<x->data)
                {link=&x->left_child;x=x->left_child;}
            else
                {link=&x->right_child;x=x->right_child;}
        }
    *link=new BTN;      // attach the node to the tree, not just to a local pointer
    (*link)->data=s;
    (*link)->left_child=(*link)->right_child=NULL;
    cout<<"insert successfully!"<<endl;
}
void Binary_Sort_Tree::Destroy(BTN *&p)
{
    if(p)
        {
            Destroy(p->left_child);
            Destroy(p->right_child);
            delete p;
            p=NULL;
        }
}
Binary_Sort_Tree::~Binary_Sort_Tree()
{
    BTN *x=&root;    // the root itself is a member, so only free its subtrees
    Destroy(x->left_child);
    Destroy(x->right_child);
}
int main()
{
    clock_t start,finish,duration;
    Binary_Sort_Tree BST;
    start=clock();
    BST.Performance_Analysis();
    finish=clock();
    cout<<"the precise time is:"<<finish-start<<endl;
    duration=(finish-start)/CLOCKS_PER_SEC;
    cout<<"the whole time is:"<<duration<<"s"<<endl;
    system("pause");
    return 0;
}
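
One quick sanity check on the structure: an in-order traversal of a binary sort tree must visit the keys in ascending order, i.e. it should reproduce the sorted dict.txt. A minimal recursive sketch against the BTN node type above (In_Order is a hypothetical helper, not part of the original program):

#include<iostream>
#include<string>
using namespace std;
// BTN as defined above: string data; BTN *left_child,*right_child;

// Prints the subtree rooted at p in ascending key order:
// left subtree first, then the node itself, then the right subtree.
void In_Order(const BTN *p)
{
    if(p)
        {
            In_Order(p->left_child);
            cout<<p->data<<endl;
            In_Order(p->right_child);
        }
}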

 

Efficiency comparison:

  Static table (binary search):

  ASL: 16.4998

  time: 1111 ms

  Dynamic table (binary sort tree):

  ASL: 21.3702

  time: 1937 ms

  The static table searches faster. That is consistent with theory: binary search costs about log2(n) comparisons per lookup, while a randomly built binary sort tree averages about 2*ln(n), roughly 1.39*log2(n), so its ASL and running time come out a noticeable margin higher.

  

  

 

Reposted from: https://www.cnblogs.com/kopanswer/archive/2012/06/06/2537503.html
