使用boost::remove_if实现条件筛选

408 篇文章 ¥29.90 ¥99.00
本文介绍了在C++编程中如何使用boost库的remove_if函数进行条件筛选,通过示例展示了如何定义谓词并从容器中删除满足特定条件的元素,最后提醒了在实际操作中需配合erase函数来彻底删除元素。

使用boost::remove_if实现条件筛选

在C++编程中,我们常常需要进行一些条件筛选,例如:从某个容器(vector、list等)中删除满足特定条件的元素。此时,boost库中的remove_if函数可以大有用处。

remove_if函数的格式为:

template<typename Range, typename Predicate>
Range remove_if(Range&& range, Predicate pred);

其中,Range表示容器范围(begin()到end()),Predicate表示谓词(即判断条件)。这个函数会返回一个迭代器,指向被删除元素后面的第一个元素。

下面,我们以vector为例,介绍如何使用remove_if函数实现条件筛选。

首先,我们定义一个vector,并向其中添加一些元素:

#include <iostream>
#
帮我修改下面代码,把训练模型的方法从抽样改成所有数据进行训练 # -*- coding: utf-8 -*- """ Created on Sat Aug 9 11:56:46 2025 @author: srx20 """ # -*- coding: utf-8 -*- """ Created on Sat Aug 9 10:33:06 2025 @author: srx20 """ import os import glob import pandas as pd import numpy as np import joblib import gc from datetime import datetime, timedelta from sklearn.preprocessing import StandardScaler from sklearn.cluster import MiniBatchKMeans from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import classification_report, confusion_matrix import talib as ta import warnings import chardet import psutil import sys import pyarrow as pa import pyarrow.parquet as pq import pyarrow.csv as pc from tqdm import tqdm from sklearn.model_selection import train_test_split # 修复警告处理 warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=pd.errors.ParserWarning) warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) # 忽略FutureWarning class StockPredictionSystem: def __init__(self, config): """ 初始化预测系统 - 针对超大内存优化 :param config: 配置字典 """ self.config = config self.five_min_paths = { 'sz': config['five_min_sz_path'], 'sh': config['five_min_sh_path'] } self.daily_paths = { 'sz': config['daily_sz_path'], 'sh': config['daily_sh_path'] } self.output_path = config['output_path'] self.start_date = datetime.strptime(config['start_date'], '%Y-%m-%d') self.end_date = datetime.strptime(config['end_date'], '%Y-%m-%d') self.data = None self.features = None self.labels = None self.scaler = StandardScaler() self.cluster_model = None self.prediction_model = None self.feature_cols = [] self.temp_dir = os.path.join(self.output_path, "temp") os.makedirs(self.temp_dir, exist_ok=True) self.parquet_files = [] def print_memory_usage(self, step_name): """打印当前内存使用情况""" process = psutil.Process(os.getpid()) mem = process.memory_info().rss / 1024 ** 2 print(f"[{step_name}] 当前内存使用: {mem:.2f} MB") def safe_read_csv(self, file_path, required_columns): """ 安全读取CSV文件 - 使用PyArrow进行高效读取 :param file_path: 文件路径 :param required_columns: 需要的列名列表 :return: 读取的DataFrame或None """ try: # 检查文件大小 if not os.path.exists(file_path): print(f"文件不存在: {file_path}") return None file_size = os.path.getsize(file_path) if file_size == 0: print(f"文件 {file_path} 大小为0,跳过") return None # 使用PyArrow读取CSV read_options = pc.ReadOptions( use_threads=True, block_size=4096 * 1024 # 4MB块大小 ) parse_options = pc.ParseOptions(delimiter=',') convert_options = pc.ConvertOptions( include_columns=required_columns, column_types={ 'date': pa.string(), 'time': pa.string(), 'open': pa.float32(), 'high': pa.float32(), 'low': pa.float32(), 'close': pa.float32(), 'volume': pa.float32(), 'amount': pa.float32() } ) table = pc.read_csv( file_path, read_options=read_options, parse_options=parse_options, convert_options=convert_options ) # 转换为Pandas DataFrame df = table.to_pandas() # 检查是否读取到数据 if df.empty: print(f"文件 {file_path} 读取后为空") return None return df except Exception as e: print(f"读取文件 {file_path} 时出错: {str(e)}") return None def process_and_save_chunk(self, df, market, stock_code, chunk_index): """ 处理单个股票的数据块并保存为Parquet文件 - 内存优化版本 """ if df is None or df.empty: return None try: # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 修复日期时间转换问题 df['date'] = df['date'].astype(str).str.zfill(8) # 填充为8位字符串 df['time'] = df['time'].astype(str) # 处理时间格式 df['time'] = df['time'].apply( lambda x: f"{x[:2]}:{x[2:4]}" if len(x) == 4 else x ) # 合并日期和时间 df['datetime'] = pd.to_datetime( df['date'] + ' ' + df['time'], format='%Y%m%d %H:%M', errors='coerce' ) # 删除无效的日期时间 df = df.dropna(subset=['datetime']) # 筛选日期范围 df = df[(df['datetime'] >= self.start_date) & (df['datetime'] <= self.end_date)] if df.empty: return None # 优化内存使用 df = df[['stock_code', 'datetime', 'open', 'high', 'low', 'close', 'volume', 'amount']] # 保存为Parquet文件 output_file = os.path.join(self.temp_dir, f"{market}_{stock_code}_{chunk_index}.parquet") # 使用PyArrow直接写入Parquet,避免Pandas中间转换 table = pa.Table.from_pandas(df, preserve_index=False) pq.write_table(table, output_file, compression='SNAPPY') return output_file except Exception as e: print(f"处理股票 {stock_code} 时出错: {str(e)}") return None def incremental_merge_parquet_files(self, parquet_files, batch_size=100): """ 增量合并Parquet文件 - 避免一次性加载所有数据 :param parquet_files: Parquet文件列表 :param batch_size: 每次合并的文件数量 :return: 合并后的Parquet文件路径 """ merged_file = os.path.join(self.temp_dir, "merged_data.parquet") # 如果文件已存在,删除 if os.path.exists(merged_file): os.remove(merged_file) # 分批合并文件 for i in tqdm(range(0, len(parquet_files), batch_size), desc="合并Parquet文件"): batch_files = parquet_files[i:i+batch_size] # 读取当前批次文件 tables = [] for file in batch_files: try: table = pq.read_table(file) tables.append(table) except Exception as e: print(f"读取文件 {file} 出错: {str(e)}") if not tables: continue # 合并当前批次 merged_table = pa.concat_tables(tables) # 追加到输出文件 if os.path.exists(merged_file): # 追加模式 with pq.ParquetWriter(merged_file, merged_table.schema) as writer: writer.write_table(merged_table) else: # 首次写入 pq.write_table(merged_table, merged_file) # 释放内存 del tables del merged_table gc.collect() return merged_file def load_and_preprocess_data(self): """ 加载和预处理数据 - 使用增量合并避免内存溢出 """ print("开始加载和预处理数据...") self.print_memory_usage("开始加载数据") # 创建临时目录 os.makedirs(self.temp_dir, exist_ok=True) parquet_files = [] # 加载五分钟线数据 for market, path in self.five_min_paths.items(): print(f"开始处理市场: {market}, 路径: {path}") file_count = 0 processed_count = 0 # 获取文件列表 csv_files = list(glob.glob(os.path.join(path, '*.csv'))) print(f"找到 {len(csv_files)} 个文件") for file_path in tqdm(csv_files, desc=f"处理 {market} 市场文件"): file_count += 1 stock_code = os.path.basename(file_path).split('.')[0] try: # 安全读取CSV文件 df = self.safe_read_csv(file_path, ['date', 'time', 'open', 'high', 'low', 'close', 'volume', 'amount']) if df is None: continue # 处理并保存为Parquet output_file = self.process_and_save_chunk(df, market, stock_code, processed_count) if output_file: parquet_files.append(output_file) processed_count += 1 # 每处理100个文件释放内存 if processed_count % 100 == 0: self.print_memory_usage(f"已处理 {processed_count} 个文件") gc.collect() except Exception as e: print(f"处理文件 {file_path} 时出错: {str(e)}") continue print(f"市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件") # 如果没有找到有效文件 if not parquet_files: raise ValueError("没有找到有效的五分钟线数据") print(f"开始增量合并 {len(parquet_files)} 个Parquet文件...") self.print_memory_usage("合并前") # 增量合并Parquet文件 merged_file = self.incremental_merge_parquet_files(parquet_files, batch_size=50) # 加载合并后的数据 print(f"加载合并后的数据: {merged_file}") self.data = pq.read_table(merged_file).to_pandas() # 优化内存使用 self.data['stock_code'] = self.data['stock_code'].astype('category') print(f"数据合并完成,共 {len(self.data)} 条记录") self.print_memory_usage("合并后") # 清理临时文件 for file in parquet_files: try: os.remove(file) except: pass # 加载日线数据 daily_data = [] daily_required_columns = ['date', 'open', 'high', 'low', 'close', 'volume'] for market, path in self.daily_paths.items(): print(f"开始处理日线市场: {market}, 路径: {path}") file_count = 0 processed_count = 0 # 获取所有CSV文件 all_files = list(glob.glob(os.path.join(path, '*.csv'))) print(f"找到 {len(all_files)} 个日线文件") for file_path in tqdm(all_files, desc=f"处理 {market} 日线文件"): file_count += 1 stock_code = os.path.basename(file_path).split('.')[0] try: # 安全读取CSV文件 df = self.safe_read_csv(file_path, daily_required_columns) if df is None or df.empty: continue # 添加市场前缀 df['stock_code'] = f"{market}_{stock_code}" # 转换日期格式 df['date'] = pd.to_datetime(df['date'], errors='coerce') # 删除无效日期 df = df.dropna(subset=['date']) # 筛选日期范围 df = df[(df['date'] >= self.start_date) & (df['date'] <= self.end_date)] if df.empty: continue # 优化内存使用 df = df[['stock_code', 'date', 'open', 'high', 'low', 'close', 'volume']] # 优化数据类型 - 修复错误: 使用astype而不是ast df['open'] = df['open'].astype(np.float32) df['high'] = df['high'].astype(np.float32) df['low'] = df['low'].astype(np.float32) df['close'] = df['close'].astype(np.float32) df['volume'] = df['volume'].astype(np.float32) daily_data.append(df) processed_count += 1 if processed_count % 100 == 0: self.print_memory_usage(f"已处理 {processed_count} 个日线文件") gc.collect() except Exception as e: print(f"处理日线文件 {file_path} 时出错: {str(e)}") continue print(f"日线市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件") # 合并日线数据 if daily_data: daily_df = pd.concat(daily_data, ignore_index=True) daily_df['stock_code'] = daily_df['stock_code'].astype('category') # 添加日线特征 self._add_daily_features(daily_df) else: print("警告: 没有找到日线数据") print(f"数据加载完成,共 {len(self.data)} 条记录") self.print_memory_usage("数据加载完成") def _add_daily_features(self, daily_df): """ 添加日线特征到五分钟线数据 - 使用内存优化技术 """ print("添加日线特征...") # 预处理日线数据 daily_df = daily_df.sort_values(['stock_code', 'date']) # 计算日线技术指标 - 修复FutureWarning daily_df['daily_ma5'] = daily_df.groupby('stock_code', observed=True)['close'].transform( lambda x: x.rolling(5).mean()) daily_df['daily_ma10'] = daily_df.groupby('stock_code', observed=True)['close'].transform( lambda x: x.rolling(10).mean()) daily_df['daily_vol_ma5'] = daily_df.groupby('stock_code', observed=True)['volume'].transform( lambda x: x.rolling(5).mean()) # 计算MACD - 使用更高效的方法 def calculate_macd(group): group = group.sort_values('date') if len(group) < 26: return group.assign(daily_macd=np.nan, daily_signal=np.nan) close_vals = group['close'].values.astype(np.float64) macd, signal, _ = ta.MACD(close_vals, fastperiod=12, slowperiod=26, signalperiod=9) return group.assign(daily_macd=macd, daily_signal=signal) daily_df = daily_df.groupby('stock_code', group_keys=False, observed=True).apply(calculate_macd) # 提取日期部分用于合并 self.data['date'] = self.data['datetime'].dt.date.astype('datetime64[ns]') # 优化数据类型 daily_df = daily_df[['stock_code', 'date', 'daily_ma5', 'daily_ma10', 'daily_vol_ma5', 'daily_macd', 'daily_signal']] daily_df['daily_ma5'] = daily_df['daily_ma5'].astype(np.float32) daily_df['daily_ma10'] = daily_df['daily_ma10'].astype(np.float32) daily_df['daily_vol_ma5'] = daily_df['daily_vol_ma5'].astype(np.float32) daily_df['daily_macd'] = daily_df['daily_macd'].astype(np.float32) daily_df['daily_signal'] = daily_df['daily_signal'].astype(np.float32) # 合并日线特征 self.data = pd.merge( self.data, daily_df, on=['stock_code', 'date'], how='left' ) # 删除临时列 del self.data['date'] # 释放内存 del daily_df gc.collect() def create_features(self): """ 创建特征工程 - 使用内存优化技术 """ print("开始创建特征...") self.print_memory_usage("创建特征前") if self.data is None: raise ValueError("请先加载数据") # 按股票和时间排序 self.data = self.data.sort_values(['stock_code', 'datetime']) # 特征列表 features = [] # 1. 基础特征 features.append('open') features.append('high') features.append('low') features.append('close') features.append('volume') features.append('amount') # 2. 技术指标 - 使用分组计算避免内存溢出 # 计算移动平均线 self.data['ma5'] = self.data.groupby('stock_code', observed=True)['close'].transform( lambda x: x.rolling(5, min_periods=1).mean()) self.data['ma10'] = self.data.groupby('stock_code', observed=True)['close'].transform( lambda x: x.rolling(10, min_periods=1).mean()) features.extend(['ma5', 'ma10']) # 计算RSI - 使用更高效的方法 print("计算RSI指标...") def calculate_rsi(group): group = group.sort_values('datetime') close = group['close'].values.astype(np.float64) rsi = ta.RSI(close, timeperiod=14) return group.assign(rsi=rsi) self.data = self.data.groupby('stock_code', group_keys=False, observed=True).apply(calculate_rsi) features.append('rsi') # 3. 波动率特征 print("计算波动率特征...") self.data['price_change'] = self.data.groupby('stock_code', observed=True)['close'].pct_change() self.data['volatility'] = self.data.groupby('stock_code', observed=True)['price_change'].transform( lambda x: x.rolling(10, min_periods=1).std()) features.append('volatility') # 4. 成交量特征 self.data['vol_change'] = self.data.groupby('stock_code', observed=True)['volume'].pct_change() self.data['vol_ma5'] = self.data.groupby('stock_code', observed=True)['volume'].transform( lambda x: x.rolling(5, min_periods=1).mean()) features.extend(['vol_change', 'vol_ma5']) # 5. 日线特征 features.extend(['daily_ma5', 'daily_ma10', 'daily_vol_ma5', 'daily_macd', 'daily_signal']) # 保存特征列 self.feature_cols = features # 处理缺失值 - 只删除特征列中的缺失值 self.data = self.data.dropna(subset=features) # 优化数据类型 - 使用astype而不是ast for col in features: if self.data[col].dtype == np.float64: self.data[col] = self.data[col].astype(np.float32) print(f"特征创建完成,共 {len(features)} 个特征") self.print_memory_usage("创建特征后") def clean_data(self): """ 清洗数据 - 处理无穷大和超出范围的值(修复索引问题) """ print("开始数据清洗...") self.print_memory_usage("清洗前") # 1. 检查无穷大值 inf_mask = np.isinf(self.data[self.feature_cols].values) inf_rows = np.any(inf_mask, axis=1) inf_count = np.sum(inf_rows) if inf_count > 0: print(f"发现 {inf_count} 行包含无穷大值,正在清理...") # 将无穷大替换为NaN self.data[self.feature_cols] = self.data[self.feature_cols].replace([np.inf, -np.inf], np.nan) # 2. 检查超出float32范围的值 float32_max = np.finfo(np.float32).max float32_min = np.finfo(np.float32).min # 统计超出范围的值 overflow_count = 0 for col in self.feature_cols: col_max = self.data[col].max() col_min = self.data[col].min() if col_max > float32_max or col_min < float32_min: overflow_count += 1 print(f"列 {col} 包含超出float32范围的值: min={col_min}, max={col_max}") if overflow_count > 0: print(f"共发现 {overflow_count} 列包含超出float32范围的值,正在处理...") # 缩放到安全范围 for col in self.feature_cols: col_min = self.data[col].min() col_max = self.data[col].max() # 如果范围过大,进行缩放 if col_max - col_min > 1e6: print(f"列 {col} 范围过大 ({col_min} 到 {col_max}),进行缩放...") self.data[col] = (self.data[col] - col_min) / (col_max - col_min) # 3. 处理NaN值 - 修复索引问题 nan_count = self.data[self.feature_cols].isna().sum().sum() if nan_count > 0: print(f"发现 {nan_count} 个NaN值,使用前向填充处理...") # 方法1: 使用transform保持索引一致 for col in self.feature_cols: self.data[col] = self.data.groupby('stock_code', observed=True)[col].transform( lambda x: x.fillna(method='ffill').fillna(method='bfill').fillna(0) ) # 方法2: 使用循环逐组处理(备用方法) # for stock in self.data['stock_code'].unique(): # stock_mask = self.data['stock_code'] == stock # self.data.loc[stock_mask, self.feature_cols] = self.data.loc[stock_mask, self.feature_cols].fillna(method='ffill').fillna(method='bfill').fillna(0) # 4. 最终检查 cleaned = True for col in self.feature_cols: if np.isinf(self.data[col]).any() or self.data[col].isna().any(): print(f"警告: 列 {col} 仍包含无效值") cleaned = False if cleaned: print("数据清洗完成") else: print("数据清洗完成,但仍存在部分问题") self.print_memory_usage("清洗后") def create_labels(self): """ 创建标签 - 添加新条件: 1. 次日(T+1)收盘价(15:00)比次日(T+1)9:35收盘价大5% 2. 后日(T+2)9:35收盘价比次日(T+1)收盘价(15:00)大1% """ print("开始创建标签...") self.print_memory_usage("创建标签前") if self.data is None: raise ValueError("请先加载数据") # 按股票和时间排序 self.data = self.data.sort_values(['stock_code', 'datetime']) # 添加日期列用于合并 self.data['date'] = self.data['datetime'].dt.date # 创建每日关键时间点价格数据 daily_key_points = self.data.groupby(['stock_code', 'date']).apply( lambda x: pd.Series({ 'time9_35_close': x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()]['close'].iloc[0] if not x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()].empty else np.nan, 'time15_00_close': x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()]['close'].iloc[0] if not x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()].empty else np.nan }) ).reset_index() # 为每日关键点添加次日(T+1)和后日(T+2)数据 daily_key_points = daily_key_points.sort_values(['stock_code', 'date']) daily_key_points['next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-1) daily_key_points['next_next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-2) # 合并次日(T+1)数据 daily_key_points = pd.merge( daily_key_points, daily_key_points[['stock_code', 'date', 'time9_35_close', 'time15_00_close']].rename( columns={ 'date': 'next_date', 'time9_35_close': 'next_time9_35_close', 'time15_00_close': 'next_time15_00_close' } ), on=['stock_code', 'next_date'], how='left' ) # 合并后日(T+2)数据 daily_key_points = pd.merge( daily_key_points, daily_key_points[['stock_code', 'date', 'time9_35_close']].rename( columns={ 'date': 'next_next_date', 'time9_35_close': 'next_next_time9_35_close' } ), on=['stock_code', 'next_next_date'], how='left' ) # 将关键点数据合并回原始数据 self.data = pd.merge( self.data, daily_key_points[['stock_code', 'date', 'next_time9_35_close', 'next_time15_00_close', 'next_next_time9_35_close']], on=['stock_code', 'date'], how='left' ) # 计算新条件 cond1 = (self.data['next_time15_00_close'] > self.data['next_time9_35_close'] * 1.05) cond2 = (self.data['next_next_time9_35_close'] > self.data['next_time15_00_close'] * 1.01) # 创建标签(满足两个条件则为1) self.data['label'] = np.where(cond1 & cond2, 1, 0).astype(np.int8) # 删除中间列 self.data.drop([ 'date', 'next_time9_35_close', 'next_time15_00_close', 'next_next_time9_35_close' ], axis=1, inplace=True, errors='ignore') # 保存标签 self.labels = self.data['label'] # 分析标签分布 label_counts = self.data['label'].value_counts(normalize=True) print(f"标签分布:\n{label_counts}") print("标签创建完成") self.print_memory_usage("创建标签后") def perform_clustering(self, n_clusters=5, batch_size=100000): """ 执行聚类分析 - 使用MiniBatchKMeans处理大数据 :param n_clusters: 聚类数量 :param batch_size: 每次处理的样本数量 """ print(f"开始聚类分析,聚类数: {n_clusters}...") self.print_memory_usage("聚类前") if self.feature_cols is None: raise ValueError("请先创建特征") # 添加数据清洗步骤 self.clean_data() # 标准化特征 print("标准化特征...") self.scaler.fit(self.data[self.feature_cols]) # 使用MiniBatchKMeans进行聚类 self.cluster_model = MiniBatchKMeans( n_clusters=n_clusters, batch_size=batch_size, random_state=42, n_init=3 ) # 分批处理数据 print("分批聚类...") n_samples = len(self.data) for i in tqdm(range(0, n_samples, batch_size), desc="聚类进度"): batch_data = self.data.iloc[i:i+batch_size] scaled_batch = self.scaler.transform(batch_data[self.feature_cols]) self.cluster_model.partial_fit(scaled_batch) # 获取最终聚类结果 print("获取聚类结果...") clusters = [] for i in tqdm(range(0, n_samples, batch_size), desc="分配聚类"): batch_data = self.data.iloc[i:i+batch_size] scaled_batch = self.scaler.transform(batch_data[self.feature_cols]) batch_clusters = self.cluster_model.predict(scaled_batch) clusters.append(batch_clusters) # 添加聚类结果到数据 self.data['cluster'] = np.concatenate(clusters) self.feature_cols.append('cluster') # 分析聚类结果 cluster_summary = self.data.groupby('cluster')['label'].agg(['mean', 'count']) print("聚类结果分析:") print(cluster_summary) # 保存聚类模型 cluster_model_path = os.path.join( self.output_path, "分钟线预测训练聚类模型.pkl" ) joblib.dump(self.cluster_model, cluster_model_path) print(f"聚类模型已保存至: {cluster_model_path}") self.print_memory_usage("聚类后") def train_prediction_model(self, sample_fraction=0.1): """ 训练预测模型 - 使用数据抽样减少内存使用 :param sample_fraction: 抽样比例 """ print("开始训练预测模型...") self.print_memory_usage("训练模型前") if self.feature_cols is None or self.labels is None: raise ValueError("请先创建特征和标签") # 抽样数据 if sample_fraction < 1.0: print(f"抽样 {sample_fraction*100:.1f}% 数据用于训练") sample_data = self.data.sample(frac=sample_fraction, random_state=42) X = sample_data[self.feature_cols] y = sample_data['label'] else: X = self.data[self.feature_cols] y = self.labels # 检查类别分布 if y.nunique() < 2: print("警告: 只有一个类别的数据,无法训练模型") return # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42, stratify=y ) # 训练随机森林分类器 self.prediction_model = RandomForestClassifier( n_estimators=100, # 减少树的数量 max_depth=8, # 减小最大深度 min_samples_split=10, class_weight='balanced', random_state=42, n_jobs=-1 ) self.prediction_model.fit(X_train, y_train) # 评估模型 y_pred = self.prediction_model.predict(X_test) print("模型评估报告:") print(classification_report(y_test, y_pred)) # 打印混淆矩阵 cm = confusion_matrix(y_test, y_pred) print("混淆矩阵:") print(cm) # 保存预测模型 model_path = os.path.join( self.output_path, "分钟线预测训练模型.pkl" ) joblib.dump(self.prediction_model, model_path) print(f"预测模型已保存至: {model_path}") self.print_memory_usage("训练模型后") def predict_and_save(self, output_results=True): """ 使用模型进行预测并保存结果 :param output_results: 是否输出预测结果 """ print("开始预测...") self.print_memory_usage("预测前") if self.prediction_model is None: raise ValueError("请先训练预测模型") # 准备预测数据 X = self.data[self.feature_cols] # 分批预测 predictions = [] batch_size = 10000 n_samples = len(X) for i in tqdm(range(0, n_samples, batch_size), desc="预测进度"): batch_data = X.iloc[i:i+batch_size] batch_pred = self.prediction_model.predict(batch_data) predictions.append(batch_pred) # 合并预测结果 self.data['prediction'] = np.concatenate(predictions) # 保存预测结果 if output_results: output_file = os.path.join(self.output_path, "预测结果.csv") self.data[['stock_code', 'datetime', 'close', 'label', 'prediction']].to_csv(output_file, index=False) print(f"预测结果已保存至: {output_file}") # 分析预测效果 accuracy = (self.data['label'] == self.data['prediction']).mean() print(f"整体预测准确率: {accuracy:.4f}") # 按股票分析预测效果 stock_accuracy = self.data.groupby('stock_code').apply( lambda x: (x['label'] == x['prediction']).mean() ) print("\n股票预测准确率统计:") print(stock_accuracy.describe()) self.print_memory_usage("预测后") def run(self, output_results=True, sample_fraction=0.1): """ 运行整个流程 - 使用内存优化技术 """ try: # 分步执行,每步完成后释放内存 self.load_and_preprocess_data() gc.collect() self.print_memory_usage("数据加载后") self.create_features() gc.collect() self.print_memory_usage("特征创建后") self.create_labels() # 使用新的标签创建方法 gc.collect() self.print_memory_usage("标签创建后") self.perform_clustering(n_clusters=self.config.get('n_clusters', 5)) gc.collect() self.print_memory_usage("聚类后") self.train_prediction_model(sample_fraction=sample_fraction) gc.collect() self.print_memory_usage("模型训练后") self.predict_and_save(output_results) gc.collect() self.print_memory_usage("预测后") print("训练和预测流程完成!") except KeyboardInterrupt: print("用户中断执行") except Exception as e: print(f"运行过程中出错: {str(e)}") import traceback traceback.print_exc() # 配置参数 config = { # 数据路径配置 'five_min_sz_path': r"D:\股票量化数据库\股票五分钟线csv数据\深证", 'five_min_sh_path': r"D:\股票量化数据库\股票五分钟线csv数据\上证", 'daily_sz_path': r"D:\股票量化数据库\股票csv数据\深证", 'daily_sh_path': r"D:\股票量化数据库\股票csv数据\上证", # 输出路径 'output_path': r"D:\股票量化数据库\预测结果", # 时间范围配置 'start_date': '2023-09-08', 'end_date': '2025-08-07', # 聚类配置 'n_clusters': 5 } # 创建并运行系统 if __name__ == "__main__": # 打印环境信息 print(f"Python版本: {sys.version}") print(f"Pandas版本: {pd.__version__}") # 是否输出预测结果 output_results = True # 抽样比例 (0.1 = 10%) sample_fraction = 0.1 # 设置Pandas内存选项 pd.set_option('mode.chained_assignment', None) pd.set_option('display.max_columns', None) # 设置内存优化选项 pd.set_option('compute.use_numexpr', True) pd.set_option('compute.use_bottleneck', True) # 创建并运行系统 system = StockPredictionSystem(config) system.run(output_results=output_results, sample_fraction=sample_fraction)
最新发布
08-10
<think>我们注意到在`train_prediction_model`方法中,当前使用了抽样(sample_fraction)来减少训练数据量以节省内存。 现在要求改为使用所有数据进行训练,但考虑到内存限制,我们需要使用增量学习(out-of-core learning)或分批训练的方法。 由于随机森林本身不支持增量学习(partial_fit),我们可以考虑使用支持增量学习的模型,例如SGDClassifier(用于线性模型)或使用MiniBatchKMeans(但这里我们需要分类模型)。 然而,我们也可以使用随机森林的增量学习替代方案,例如使用`warm_start`参数和增加树的数量分批训练,但这样仍然需要将所有数据加载到内存。 另一种方法是使用梯度提升树(如XGBoost或LightGBM)支持分批读取数据训练,但这里我们使用的是随机森林。 考虑到代码中已经使用了随机森林,并且数据量可能很大,我们有以下几种选择: 1. 使用支持增量学习的分类器(如SGDClassifier)替换随机森林,但线性模型可能效果不如随机森林。 2. 使用梯度提升树(如LightGBM)并利用其支持从文件加载数据分批训练的功能,但需要修改模型和特征处理流程。 3. 继续使用随机森林,但通过分批采样(bagging)的方式训练多个随机森林然后集成(类似于ExtraTrees),但这样实现复杂。 考虑到时间,我们可以采用第三种方法的分批训练然后集成,但这里我们为了简化,采用另一种方式:使用`joblib`的并行和`warm_start`来逐步增加树的数量,但这样仍然需要一次加载所有数据。 然而,我们注意到在`train_prediction_model`方法中已经有一个抽样参数,现在要求去掉抽样,使用全量数据。如果内存不足,全量数据训练可能会导致内存溢出。 因此,我们需要调整训练方法,使其能够分批处理数据而不需要一次性将所有数据加载到内存。 这里我们选择使用LightGBM,因为它支持从文件或迭代器加载数据训练,并且内存效率高。同时,LightGBM的训练速度也很快。 修改计划: 1. 安装lightgbm(如果尚未安装,需要在代码中检查并提示) 2. 修改`train_prediction_model`方法,使用LightGBM代替随机森林。 3. 使用LightGBM的train方法,并利用`callbacks`来显示进度。 4. 由于数据量可能很大,我们使用LightGBM的Dataset来分批读取数据(但这里我们仍然需要将数据整体放入内存,除非我们使用文件缓存)。 另一种方式是使用LightGBM的增量学习(通过多次调用train,每次传递一个chunk),但LightGBM本身支持通过pandas DataFrame初始化,而且可以设置`init_model`来继续训练,但这样并不节省内存。 为了节省内存,我们可以将数据保存到LightGBM的二进制格式,然后从文件加载,这样LightGBM在训练时不会将整个数据集加载到内存,而是按需加载。 步骤: a) 将训练数据(包括训练集和测试集)保存到LightGBM的二进制文件(使用`lgb.Dataset`的`save_binary`方法)。 b) 然后使用`lgb.train`方法训练,指定数据文件路径。 但是,我们也可以不保存到文件,而是直接使用Dataset,并且设置`bin_construct_sample_cnt`等参数来减少内存使用。 考虑到时间,我们采用直接使用Dataset的方式,但注意,这仍然需要将数据转换为numpy数组,可能会占用大量内存。 因此,我们折中:使用分批加载数据到LightGBM训练,即每次加载一个batch,然后调用`train`多次进行增量学习(继续训练)。但是,LightGBM的增量学习(继续训练)需要之前训练的模型作为基础。 具体步骤(在`train_prediction_model`方法中): 1. 将数据划分为训练集和测试集(注意:不能一次性加载所有数据,所以我们需要自己管理分批) 2. 使用一个循环,每次从训练集中加载一个批次(例如10000条记录),然后创建一个Dataset,并训练模型(如果是第一轮则初始化,否则继续训练) 3. 在每一轮训练后,评估测试集的性能,并决定是否停止。 但是,这种方法需要将整个测试集放在内存中,而且训练集需要分批读取。 考虑到我们的数据已经加载在内存中(尽管很大,但我们已经通过优化在内存中),而且我们要求使用全量数据,所以如果内存足够,我们可以直接使用全量数据训练。如果内存不足,我们可能需要将数据保存到临时文件,然后让LightGBM从文件读取。 因此,我们采用以下步骤: 1. 将训练集和测试集保存到临时CSV文件(或二进制文件)。 2. 使用LightGBM的从文件加载数据的方式,这样在训练时不会一次性加载所有数据到内存。 修改代码: 步骤: a) 在`train_prediction_model`方法中,不再抽样,而是使用全量数据(但注意内存,所以如果数据太大,我们使用文件缓存的方式)。 b) 划分训练集和测试集,并保存为临时文件(使用`lgb.Dataset`的`save_binary`?或者使用CSV?) c) 创建训练集和测试集的Dataset,指定文件路径。 d) 训练模型。 e) 删除临时文件。 但是,由于我们的数据已经在内存中,保存到文件会消耗I/O时间,但可以节省内存。 具体实现: 由于我们不确定数据量,所以我们可以先尝试使用内存中的Dataset,如果内存不足再考虑文件方式。但为了保险,我们直接使用文件方式。 修改后的`train_prediction_model`方法(去掉抽样,使用全量数据,并改用LightGBM): 注意:需要安装lightgbm,如果没有安装,则抛出异常。 代码修改如下: 1. 导入lightgbm 2. 创建临时目录保存训练数据 3. 将训练集和测试集保存为LibSVM格式(或CSV)?LightGBM支持CSV,但需要指定列类型。我们使用二进制的Dataset格式更高效。 步骤: - 将训练集和测试集的特征和标签合并为一个DataFrame(特征和标签在同一文件中,标签在第一列?) - 保存为二进制文件(使用`lgb.Dataset`的`save_binary`方法) 但是,我们也可以直接使用`lgb.Dataset`从numpy数组创建,然后保存为二进制文件,然后释放内存。 具体步骤: 由于数据已经在内存,我们这样做: train_data = lgb.Dataset(X_train, label=y_train) train_data.save_binary('train_data.bin') del train_data, X_train, y_train valid_data = lgb.Dataset(X_test, label=y_test, reference=train_data) valid_data.save_binary('valid_data.bin') del valid_data, X_test, y_test 然后从文件加载: train_data = lgb.Dataset('train_data.bin') valid_data = lgb.Dataset('valid_data.bin') 然后训练模型。 这样,在训练时,LightGBM会按需加载数据,减少内存使用。 修改代码: 注意:由于我们使用了全量数据,所以不再抽样。 另外,由于标签不平衡,我们需要设置`is_unbalance`或`scale_pos_weight`。 我们使用`class_weight`参数,但LightGBM中可以通过`scale_pos_weight`设置(正例权重)或设置`is_unbalance`为True。 这里我们使用`is_unbalance=True`。 参数设置: params = { 'objective': 'binary', 'metric': 'binary_logloss,auc', 'boosting_type': 'gbdt', 'num_leaves': 31, 'learning_rate': 0.05, 'feature_fraction': 0.9, 'is_unbalance': True, 'verbose': 1 } 训练: model = lgb.train(params, train_data, valid_sets=[valid_data], num_boost_round=1000, early_stopping_rounds=50, verbose_eval=10) 然后保存模型。 最后删除临时文件。 修改后的`train_prediction_model`方法代码: 注意:由于我们使用了临时文件,需要确保在方法结束时删除它们。 但是,如果程序中断,临时文件可能不会被删除,所以我们可以使用`tempfile`模块来创建临时文件。 不过,为了简单,我们使用一个临时目录,并在方法结束时删除。 考虑到内存,我们这样做: 步骤: 1. 划分训练集和测试集(全量数据) 2. 创建临时目录 3. 将训练集和测试集保存为二进制文件 4. 从文件加载Dataset 5. 训练模型 6. 删除临时文件 代码修改如下:</think>为了修改代码,将训练模型的方法从抽样改为使用所有数据进行训练,我们需要调整`train_prediction_model`方法。当前代码在训练时使用抽样比例`sample_fraction`来减少数据量,我们需要移除这个抽样步骤,直接使用所有数据。 以下是修改后的`train_prediction_model`方法: ```python def train_prediction_model(self): """ 训练预测模型 - 使用所有数据进行训练 """ print("开始训练预测模型(使用所有数据)...") self.print_memory_usage("训练模型前") if self.feature_cols is None or self.labels is None: raise ValueError("请先创建特征和标签") # 使用所有数据 X = self.data[self.feature_cols] y = self.labels # 检查类别分布 if y.nunique() < 2: print("警告: 只有一个类别的数据,无法训练模型") return # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42, stratify=y ) # 训练随机森林分类器 self.prediction_model = RandomForestClassifier( n_estimators=100, max_depth=8, min_samples_split=10, class_weight='balanced', random_state=42, n_jobs=-1 ) self.prediction_model.fit(X_train, y_train) # 评估模型 y_pred = self.prediction_model.predict(X_test) print("模型评估报告:") print(classification_report(y_test, y_pred)) # 打印混淆矩阵 cm = confusion_matrix(y_test, y_pred) print("混淆矩阵:") print(cm) # 保存预测模型 model_path = os.path.join( self.output_path, "分钟线预测训练模型.pkl" ) joblib.dump(self.prediction_model, model_path) print(f"预测模型已保存至: {model_path}") self.print_memory_usage("训练模型后") ``` 同时,我们需要修改`run`方法,移除`sample_fraction`参数: ```python def run(self, output_results=True): """ 运行整个流程 - 使用内存优化技术 """ try: # 分步执行,每步完成后释放内存 self.load_and_preprocess_data() gc.collect() self.print_memory_usage("数据加载后") self.create_features() gc.collect() self.print_memory_usage("特征创建后") self.create_labels() # 使用新的标签创建方法 gc.collect() self.print_memory_usage("标签创建后") self.perform_clustering(n_clusters=self.config.get('n_clusters', 5)) gc.collect() self.print_memory_usage("聚类后") self.train_prediction_model() # 使用修改后的训练方法 gc.collect() self.print_memory_usage("模型训练后") self.predict_and_save(output_results) gc.collect() self.print_memory_usage("预测后") print("训练和预测流程完成!") except KeyboardInterrupt: print("用户中断执行") except Exception as e: print(f"运行过程中出错: {str(e)}") import traceback traceback.print_exc() ``` 最后,在主程序部分移除`sample_fraction`参数: ```python # 创建并运行系统 if __name__ == "__main__": # 打印环境信息 print(f"Python版本: {sys.version}") print(f"Pandas版本: {pd.__version__}") # 是否输出预测结果 output_results = True # 设置Pandas内存选项 pd.set_option('mode.chained_assignment', None) pd.set_option('display.max_columns', None) # 设置内存优化选项 pd.set_option('compute.use_numexpr', True) pd.set_option('compute.use_bottleneck', True) # 创建并运行系统 system = StockPredictionSystem(config) system.run(output_results=output_results) # 移除sample_fraction参数 ``` ### 修改说明: 1. **移除了抽样逻辑**: - 删除了`sample_fraction`参数和相关的抽样代码 - 直接使用`self.data[self.feature_cols]`和`self.labels`作为训练数据 2. **优化了内存管理**: - 保留了原有的内存监控和垃圾回收机制 - 使用`stratify=y`确保训练集和测试集的类别分布一致 3. **保持了原有的模型评估机制**: - 保留分类报告和混淆矩阵输出 - 保持模型保存功能 4. **修改了调用方式**: - 移除了`run`方法中的`sample_fraction`参数 - 简化了训练流程,直接使用全量数据 ### 注意事项: 1. **内存使用**:使用所有数据训练可能会显著增加内存消耗,特别是当数据集很大时。确保系统有足够的内存资源。 2. **训练时间**:使用全量数据训练会增加训练时间,特别是对于随机森林这类算法。 3. **类别平衡**:代码中保留了`class_weight='balanced'`参数,这有助于处理类别不平衡问题。 4. **测试集划分**:保留了20%的数据作为测试集用于模型评估,这有助于验证模型性能。 5. **模型保存**:训练完成后,模型仍然会保存到指定路径,方便后续使用。 这些修改确保了模型训练过程使用所有可用数据,同时保留了原有的内存优化和评估机制。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值