帮我修改下面代码,把训练模型的方法从抽样改成所有数据进行训练
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 9 11:56:46 2025
@author: srx20
"""
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 9 10:33:06 2025
@author: srx20
"""
import os
import glob
import pandas as pd
import numpy as np
import joblib
import gc
from datetime import datetime, timedelta
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import MiniBatchKMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import talib as ta
import warnings
import chardet
import psutil
import sys
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
from tqdm import tqdm
from sklearn.model_selection import train_test_split
# 修复警告处理
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=pd.errors.ParserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning) # 忽略FutureWarning
class StockPredictionSystem:
def __init__(self, config):
"""
初始化预测系统 - 针对超大内存优化
:param config: 配置字典
"""
self.config = config
self.five_min_paths = {
'sz': config['five_min_sz_path'],
'sh': config['five_min_sh_path']
}
self.daily_paths = {
'sz': config['daily_sz_path'],
'sh': config['daily_sh_path']
}
self.output_path = config['output_path']
self.start_date = datetime.strptime(config['start_date'], '%Y-%m-%d')
self.end_date = datetime.strptime(config['end_date'], '%Y-%m-%d')
self.data = None
self.features = None
self.labels = None
self.scaler = StandardScaler()
self.cluster_model = None
self.prediction_model = None
self.feature_cols = []
self.temp_dir = os.path.join(self.output_path, "temp")
os.makedirs(self.temp_dir, exist_ok=True)
self.parquet_files = []
def print_memory_usage(self, step_name):
"""打印当前内存使用情况"""
process = psutil.Process(os.getpid())
mem = process.memory_info().rss / 1024 ** 2
print(f"[{step_name}] 当前内存使用: {mem:.2f} MB")
def safe_read_csv(self, file_path, required_columns):
"""
安全读取CSV文件 - 使用PyArrow进行高效读取
:param file_path: 文件路径
:param required_columns: 需要的列名列表
:return: 读取的DataFrame或None
"""
try:
# 检查文件大小
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return None
file_size = os.path.getsize(file_path)
if file_size == 0:
print(f"文件 {file_path} 大小为0,跳过")
return None
# 使用PyArrow读取CSV
read_options = pc.ReadOptions(
use_threads=True,
block_size=4096 * 1024 # 4MB块大小
)
parse_options = pc.ParseOptions(delimiter=',')
convert_options = pc.ConvertOptions(
include_columns=required_columns,
column_types={
'date': pa.string(),
'time': pa.string(),
'open': pa.float32(),
'high': pa.float32(),
'low': pa.float32(),
'close': pa.float32(),
'volume': pa.float32(),
'amount': pa.float32()
}
)
table = pc.read_csv(
file_path,
read_options=read_options,
parse_options=parse_options,
convert_options=convert_options
)
# 转换为Pandas DataFrame
df = table.to_pandas()
# 检查是否读取到数据
if df.empty:
print(f"文件 {file_path} 读取后为空")
return None
return df
except Exception as e:
print(f"读取文件 {file_path} 时出错: {str(e)}")
return None
def process_and_save_chunk(self, df, market, stock_code, chunk_index):
"""
处理单个股票的数据块并保存为Parquet文件 - 内存优化版本
"""
if df is None or df.empty:
return None
try:
# 添加市场前缀
df['stock_code'] = f"{market}_{stock_code}"
# 修复日期时间转换问题
df['date'] = df['date'].astype(str).str.zfill(8) # 填充为8位字符串
df['time'] = df['time'].astype(str)
# 处理时间格式
df['time'] = df['time'].apply(
lambda x: f"{x[:2]}:{x[2:4]}" if len(x) == 4 else x
)
# 合并日期和时间
df['datetime'] = pd.to_datetime(
df['date'] + ' ' + df['time'],
format='%Y%m%d %H:%M',
errors='coerce'
)
# 删除无效的日期时间
df = df.dropna(subset=['datetime'])
# 筛选日期范围
df = df[(df['datetime'] >= self.start_date) &
(df['datetime'] <= self.end_date)]
if df.empty:
return None
# 优化内存使用
df = df[['stock_code', 'datetime', 'open', 'high', 'low', 'close', 'volume', 'amount']]
# 保存为Parquet文件
output_file = os.path.join(self.temp_dir, f"{market}_{stock_code}_{chunk_index}.parquet")
# 使用PyArrow直接写入Parquet,避免Pandas中间转换
table = pa.Table.from_pandas(df, preserve_index=False)
pq.write_table(table, output_file, compression='SNAPPY')
return output_file
except Exception as e:
print(f"处理股票 {stock_code} 时出错: {str(e)}")
return None
def incremental_merge_parquet_files(self, parquet_files, batch_size=100):
"""
增量合并Parquet文件 - 避免一次性加载所有数据
:param parquet_files: Parquet文件列表
:param batch_size: 每次合并的文件数量
:return: 合并后的Parquet文件路径
"""
merged_file = os.path.join(self.temp_dir, "merged_data.parquet")
# 如果文件已存在,删除
if os.path.exists(merged_file):
os.remove(merged_file)
# 分批合并文件
for i in tqdm(range(0, len(parquet_files), batch_size), desc="合并Parquet文件"):
batch_files = parquet_files[i:i+batch_size]
# 读取当前批次文件
tables = []
for file in batch_files:
try:
table = pq.read_table(file)
tables.append(table)
except Exception as e:
print(f"读取文件 {file} 出错: {str(e)}")
if not tables:
continue
# 合并当前批次
merged_table = pa.concat_tables(tables)
# 追加到输出文件
if os.path.exists(merged_file):
# 追加模式
with pq.ParquetWriter(merged_file, merged_table.schema) as writer:
writer.write_table(merged_table)
else:
# 首次写入
pq.write_table(merged_table, merged_file)
# 释放内存
del tables
del merged_table
gc.collect()
return merged_file
def load_and_preprocess_data(self):
"""
加载和预处理数据 - 使用增量合并避免内存溢出
"""
print("开始加载和预处理数据...")
self.print_memory_usage("开始加载数据")
# 创建临时目录
os.makedirs(self.temp_dir, exist_ok=True)
parquet_files = []
# 加载五分钟线数据
for market, path in self.five_min_paths.items():
print(f"开始处理市场: {market}, 路径: {path}")
file_count = 0
processed_count = 0
# 获取文件列表
csv_files = list(glob.glob(os.path.join(path, '*.csv')))
print(f"找到 {len(csv_files)} 个文件")
for file_path in tqdm(csv_files, desc=f"处理 {market} 市场文件"):
file_count += 1
stock_code = os.path.basename(file_path).split('.')[0]
try:
# 安全读取CSV文件
df = self.safe_read_csv(file_path, ['date', 'time', 'open', 'high', 'low', 'close', 'volume', 'amount'])
if df is None:
continue
# 处理并保存为Parquet
output_file = self.process_and_save_chunk(df, market, stock_code, processed_count)
if output_file:
parquet_files.append(output_file)
processed_count += 1
# 每处理100个文件释放内存
if processed_count % 100 == 0:
self.print_memory_usage(f"已处理 {processed_count} 个文件")
gc.collect()
except Exception as e:
print(f"处理文件 {file_path} 时出错: {str(e)}")
continue
print(f"市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件")
# 如果没有找到有效文件
if not parquet_files:
raise ValueError("没有找到有效的五分钟线数据")
print(f"开始增量合并 {len(parquet_files)} 个Parquet文件...")
self.print_memory_usage("合并前")
# 增量合并Parquet文件
merged_file = self.incremental_merge_parquet_files(parquet_files, batch_size=50)
# 加载合并后的数据
print(f"加载合并后的数据: {merged_file}")
self.data = pq.read_table(merged_file).to_pandas()
# 优化内存使用
self.data['stock_code'] = self.data['stock_code'].astype('category')
print(f"数据合并完成,共 {len(self.data)} 条记录")
self.print_memory_usage("合并后")
# 清理临时文件
for file in parquet_files:
try:
os.remove(file)
except:
pass
# 加载日线数据
daily_data = []
daily_required_columns = ['date', 'open', 'high', 'low', 'close', 'volume']
for market, path in self.daily_paths.items():
print(f"开始处理日线市场: {market}, 路径: {path}")
file_count = 0
processed_count = 0
# 获取所有CSV文件
all_files = list(glob.glob(os.path.join(path, '*.csv')))
print(f"找到 {len(all_files)} 个日线文件")
for file_path in tqdm(all_files, desc=f"处理 {market} 日线文件"):
file_count += 1
stock_code = os.path.basename(file_path).split('.')[0]
try:
# 安全读取CSV文件
df = self.safe_read_csv(file_path, daily_required_columns)
if df is None or df.empty:
continue
# 添加市场前缀
df['stock_code'] = f"{market}_{stock_code}"
# 转换日期格式
df['date'] = pd.to_datetime(df['date'], errors='coerce')
# 删除无效日期
df = df.dropna(subset=['date'])
# 筛选日期范围
df = df[(df['date'] >= self.start_date) &
(df['date'] <= self.end_date)]
if df.empty:
continue
# 优化内存使用
df = df[['stock_code', 'date', 'open', 'high', 'low', 'close', 'volume']]
# 优化数据类型 - 修复错误: 使用astype而不是ast
df['open'] = df['open'].astype(np.float32)
df['high'] = df['high'].astype(np.float32)
df['low'] = df['low'].astype(np.float32)
df['close'] = df['close'].astype(np.float32)
df['volume'] = df['volume'].astype(np.float32)
daily_data.append(df)
processed_count += 1
if processed_count % 100 == 0:
self.print_memory_usage(f"已处理 {processed_count} 个日线文件")
gc.collect()
except Exception as e:
print(f"处理日线文件 {file_path} 时出错: {str(e)}")
continue
print(f"日线市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件")
# 合并日线数据
if daily_data:
daily_df = pd.concat(daily_data, ignore_index=True)
daily_df['stock_code'] = daily_df['stock_code'].astype('category')
# 添加日线特征
self._add_daily_features(daily_df)
else:
print("警告: 没有找到日线数据")
print(f"数据加载完成,共 {len(self.data)} 条记录")
self.print_memory_usage("数据加载完成")
def _add_daily_features(self, daily_df):
"""
添加日线特征到五分钟线数据 - 使用内存优化技术
"""
print("添加日线特征...")
# 预处理日线数据
daily_df = daily_df.sort_values(['stock_code', 'date'])
# 计算日线技术指标 - 修复FutureWarning
daily_df['daily_ma5'] = daily_df.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(5).mean())
daily_df['daily_ma10'] = daily_df.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(10).mean())
daily_df['daily_vol_ma5'] = daily_df.groupby('stock_code', observed=True)['volume'].transform(
lambda x: x.rolling(5).mean())
# 计算MACD - 使用更高效的方法
def calculate_macd(group):
group = group.sort_values('date')
if len(group) < 26:
return group.assign(daily_macd=np.nan, daily_signal=np.nan)
close_vals = group['close'].values.astype(np.float64)
macd, signal, _ = ta.MACD(close_vals, fastperiod=12, slowperiod=26, signalperiod=9)
return group.assign(daily_macd=macd, daily_signal=signal)
daily_df = daily_df.groupby('stock_code', group_keys=False, observed=True).apply(calculate_macd)
# 提取日期部分用于合并
self.data['date'] = self.data['datetime'].dt.date.astype('datetime64[ns]')
# 优化数据类型
daily_df = daily_df[['stock_code', 'date', 'daily_ma5', 'daily_ma10',
'daily_vol_ma5', 'daily_macd', 'daily_signal']]
daily_df['daily_ma5'] = daily_df['daily_ma5'].astype(np.float32)
daily_df['daily_ma10'] = daily_df['daily_ma10'].astype(np.float32)
daily_df['daily_vol_ma5'] = daily_df['daily_vol_ma5'].astype(np.float32)
daily_df['daily_macd'] = daily_df['daily_macd'].astype(np.float32)
daily_df['daily_signal'] = daily_df['daily_signal'].astype(np.float32)
# 合并日线特征
self.data = pd.merge(
self.data,
daily_df,
on=['stock_code', 'date'],
how='left'
)
# 删除临时列
del self.data['date']
# 释放内存
del daily_df
gc.collect()
def create_features(self):
"""
创建特征工程 - 使用内存优化技术
"""
print("开始创建特征...")
self.print_memory_usage("创建特征前")
if self.data is None:
raise ValueError("请先加载数据")
# 按股票和时间排序
self.data = self.data.sort_values(['stock_code', 'datetime'])
# 特征列表
features = []
# 1. 基础特征
features.append('open')
features.append('high')
features.append('low')
features.append('close')
features.append('volume')
features.append('amount')
# 2. 技术指标 - 使用分组计算避免内存溢出
# 计算移动平均线
self.data['ma5'] = self.data.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(5, min_periods=1).mean())
self.data['ma10'] = self.data.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(10, min_periods=1).mean())
features.extend(['ma5', 'ma10'])
# 计算RSI - 使用更高效的方法
print("计算RSI指标...")
def calculate_rsi(group):
group = group.sort_values('datetime')
close = group['close'].values.astype(np.float64)
rsi = ta.RSI(close, timeperiod=14)
return group.assign(rsi=rsi)
self.data = self.data.groupby('stock_code', group_keys=False, observed=True).apply(calculate_rsi)
features.append('rsi')
# 3. 波动率特征
print("计算波动率特征...")
self.data['price_change'] = self.data.groupby('stock_code', observed=True)['close'].pct_change()
self.data['volatility'] = self.data.groupby('stock_code', observed=True)['price_change'].transform(
lambda x: x.rolling(10, min_periods=1).std())
features.append('volatility')
# 4. 成交量特征
self.data['vol_change'] = self.data.groupby('stock_code', observed=True)['volume'].pct_change()
self.data['vol_ma5'] = self.data.groupby('stock_code', observed=True)['volume'].transform(
lambda x: x.rolling(5, min_periods=1).mean())
features.extend(['vol_change', 'vol_ma5'])
# 5. 日线特征
features.extend(['daily_ma5', 'daily_ma10', 'daily_vol_ma5',
'daily_macd', 'daily_signal'])
# 保存特征列
self.feature_cols = features
# 处理缺失值 - 只删除特征列中的缺失值
self.data = self.data.dropna(subset=features)
# 优化数据类型 - 使用astype而不是ast
for col in features:
if self.data[col].dtype == np.float64:
self.data[col] = self.data[col].astype(np.float32)
print(f"特征创建完成,共 {len(features)} 个特征")
self.print_memory_usage("创建特征后")
def clean_data(self):
"""
清洗数据 - 处理无穷大和超出范围的值(修复索引问题)
"""
print("开始数据清洗...")
self.print_memory_usage("清洗前")
# 1. 检查无穷大值
inf_mask = np.isinf(self.data[self.feature_cols].values)
inf_rows = np.any(inf_mask, axis=1)
inf_count = np.sum(inf_rows)
if inf_count > 0:
print(f"发现 {inf_count} 行包含无穷大值,正在清理...")
# 将无穷大替换为NaN
self.data[self.feature_cols] = self.data[self.feature_cols].replace([np.inf, -np.inf], np.nan)
# 2. 检查超出float32范围的值
float32_max = np.finfo(np.float32).max
float32_min = np.finfo(np.float32).min
# 统计超出范围的值
overflow_count = 0
for col in self.feature_cols:
col_max = self.data[col].max()
col_min = self.data[col].min()
if col_max > float32_max or col_min < float32_min:
overflow_count += 1
print(f"列 {col} 包含超出float32范围的值: min={col_min}, max={col_max}")
if overflow_count > 0:
print(f"共发现 {overflow_count} 列包含超出float32范围的值,正在处理...")
# 缩放到安全范围
for col in self.feature_cols:
col_min = self.data[col].min()
col_max = self.data[col].max()
# 如果范围过大,进行缩放
if col_max - col_min > 1e6:
print(f"列 {col} 范围过大 ({col_min} 到 {col_max}),进行缩放...")
self.data[col] = (self.data[col] - col_min) / (col_max - col_min)
# 3. 处理NaN值 - 修复索引问题
nan_count = self.data[self.feature_cols].isna().sum().sum()
if nan_count > 0:
print(f"发现 {nan_count} 个NaN值,使用前向填充处理...")
# 方法1: 使用transform保持索引一致
for col in self.feature_cols:
self.data[col] = self.data.groupby('stock_code', observed=True)[col].transform(
lambda x: x.fillna(method='ffill').fillna(method='bfill').fillna(0)
)
# 方法2: 使用循环逐组处理(备用方法)
# for stock in self.data['stock_code'].unique():
# stock_mask = self.data['stock_code'] == stock
# self.data.loc[stock_mask, self.feature_cols] = self.data.loc[stock_mask, self.feature_cols].fillna(method='ffill').fillna(method='bfill').fillna(0)
# 4. 最终检查
cleaned = True
for col in self.feature_cols:
if np.isinf(self.data[col]).any() or self.data[col].isna().any():
print(f"警告: 列 {col} 仍包含无效值")
cleaned = False
if cleaned:
print("数据清洗完成")
else:
print("数据清洗完成,但仍存在部分问题")
self.print_memory_usage("清洗后")
def create_labels(self):
"""
创建标签 - 添加新条件:
1. 次日(T+1)收盘价(15:00)比次日(T+1)9:35收盘价大5%
2. 后日(T+2)9:35收盘价比次日(T+1)收盘价(15:00)大1%
"""
print("开始创建标签...")
self.print_memory_usage("创建标签前")
if self.data is None:
raise ValueError("请先加载数据")
# 按股票和时间排序
self.data = self.data.sort_values(['stock_code', 'datetime'])
# 添加日期列用于合并
self.data['date'] = self.data['datetime'].dt.date
# 创建每日关键时间点价格数据
daily_key_points = self.data.groupby(['stock_code', 'date']).apply(
lambda x: pd.Series({
'time9_35_close': x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()]['close'].iloc[0]
if not x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()].empty else np.nan,
'time15_00_close': x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()]['close'].iloc[0]
if not x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()].empty else np.nan
})
).reset_index()
# 为每日关键点添加次日(T+1)和后日(T+2)数据
daily_key_points = daily_key_points.sort_values(['stock_code', 'date'])
daily_key_points['next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-1)
daily_key_points['next_next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-2)
# 合并次日(T+1)数据
daily_key_points = pd.merge(
daily_key_points,
daily_key_points[['stock_code', 'date', 'time9_35_close', 'time15_00_close']].rename(
columns={
'date': 'next_date',
'time9_35_close': 'next_time9_35_close',
'time15_00_close': 'next_time15_00_close'
}
),
on=['stock_code', 'next_date'],
how='left'
)
# 合并后日(T+2)数据
daily_key_points = pd.merge(
daily_key_points,
daily_key_points[['stock_code', 'date', 'time9_35_close']].rename(
columns={
'date': 'next_next_date',
'time9_35_close': 'next_next_time9_35_close'
}
),
on=['stock_code', 'next_next_date'],
how='left'
)
# 将关键点数据合并回原始数据
self.data = pd.merge(
self.data,
daily_key_points[['stock_code', 'date', 'next_time9_35_close', 'next_time15_00_close', 'next_next_time9_35_close']],
on=['stock_code', 'date'],
how='left'
)
# 计算新条件
cond1 = (self.data['next_time15_00_close'] > self.data['next_time9_35_close'] * 1.05)
cond2 = (self.data['next_next_time9_35_close'] > self.data['next_time15_00_close'] * 1.01)
# 创建标签(满足两个条件则为1)
self.data['label'] = np.where(cond1 & cond2, 1, 0).astype(np.int8)
# 删除中间列
self.data.drop([
'date',
'next_time9_35_close',
'next_time15_00_close',
'next_next_time9_35_close'
], axis=1, inplace=True, errors='ignore')
# 保存标签
self.labels = self.data['label']
# 分析标签分布
label_counts = self.data['label'].value_counts(normalize=True)
print(f"标签分布:\n{label_counts}")
print("标签创建完成")
self.print_memory_usage("创建标签后")
def perform_clustering(self, n_clusters=5, batch_size=100000):
"""
执行聚类分析 - 使用MiniBatchKMeans处理大数据
:param n_clusters: 聚类数量
:param batch_size: 每次处理的样本数量
"""
print(f"开始聚类分析,聚类数: {n_clusters}...")
self.print_memory_usage("聚类前")
if self.feature_cols is None:
raise ValueError("请先创建特征")
# 添加数据清洗步骤
self.clean_data()
# 标准化特征
print("标准化特征...")
self.scaler.fit(self.data[self.feature_cols])
# 使用MiniBatchKMeans进行聚类
self.cluster_model = MiniBatchKMeans(
n_clusters=n_clusters,
batch_size=batch_size,
random_state=42,
n_init=3
)
# 分批处理数据
print("分批聚类...")
n_samples = len(self.data)
for i in tqdm(range(0, n_samples, batch_size), desc="聚类进度"):
batch_data = self.data.iloc[i:i+batch_size]
scaled_batch = self.scaler.transform(batch_data[self.feature_cols])
self.cluster_model.partial_fit(scaled_batch)
# 获取最终聚类结果
print("获取聚类结果...")
clusters = []
for i in tqdm(range(0, n_samples, batch_size), desc="分配聚类"):
batch_data = self.data.iloc[i:i+batch_size]
scaled_batch = self.scaler.transform(batch_data[self.feature_cols])
batch_clusters = self.cluster_model.predict(scaled_batch)
clusters.append(batch_clusters)
# 添加聚类结果到数据
self.data['cluster'] = np.concatenate(clusters)
self.feature_cols.append('cluster')
# 分析聚类结果
cluster_summary = self.data.groupby('cluster')['label'].agg(['mean', 'count'])
print("聚类结果分析:")
print(cluster_summary)
# 保存聚类模型
cluster_model_path = os.path.join(
self.output_path,
"分钟线预测训练聚类模型.pkl"
)
joblib.dump(self.cluster_model, cluster_model_path)
print(f"聚类模型已保存至: {cluster_model_path}")
self.print_memory_usage("聚类后")
def train_prediction_model(self, sample_fraction=0.1):
"""
训练预测模型 - 使用数据抽样减少内存使用
:param sample_fraction: 抽样比例
"""
print("开始训练预测模型...")
self.print_memory_usage("训练模型前")
if self.feature_cols is None or self.labels is None:
raise ValueError("请先创建特征和标签")
# 抽样数据
if sample_fraction < 1.0:
print(f"抽样 {sample_fraction*100:.1f}% 数据用于训练")
sample_data = self.data.sample(frac=sample_fraction, random_state=42)
X = sample_data[self.feature_cols]
y = sample_data['label']
else:
X = self.data[self.feature_cols]
y = self.labels
# 检查类别分布
if y.nunique() < 2:
print("警告: 只有一个类别的数据,无法训练模型")
return
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 训练随机森林分类器
self.prediction_model = RandomForestClassifier(
n_estimators=100, # 减少树的数量
max_depth=8, # 减小最大深度
min_samples_split=10,
class_weight='balanced',
random_state=42,
n_jobs=-1
)
self.prediction_model.fit(X_train, y_train)
# 评估模型
y_pred = self.prediction_model.predict(X_test)
print("模型评估报告:")
print(classification_report(y_test, y_pred))
# 打印混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print("混淆矩阵:")
print(cm)
# 保存预测模型
model_path = os.path.join(
self.output_path,
"分钟线预测训练模型.pkl"
)
joblib.dump(self.prediction_model, model_path)
print(f"预测模型已保存至: {model_path}")
self.print_memory_usage("训练模型后")
def predict_and_save(self, output_results=True):
"""
使用模型进行预测并保存结果
:param output_results: 是否输出预测结果
"""
print("开始预测...")
self.print_memory_usage("预测前")
if self.prediction_model is None:
raise ValueError("请先训练预测模型")
# 准备预测数据
X = self.data[self.feature_cols]
# 分批预测
predictions = []
batch_size = 10000
n_samples = len(X)
for i in tqdm(range(0, n_samples, batch_size), desc="预测进度"):
batch_data = X.iloc[i:i+batch_size]
batch_pred = self.prediction_model.predict(batch_data)
predictions.append(batch_pred)
# 合并预测结果
self.data['prediction'] = np.concatenate(predictions)
# 保存预测结果
if output_results:
output_file = os.path.join(self.output_path, "预测结果.csv")
self.data[['stock_code', 'datetime', 'close', 'label', 'prediction']].to_csv(output_file, index=False)
print(f"预测结果已保存至: {output_file}")
# 分析预测效果
accuracy = (self.data['label'] == self.data['prediction']).mean()
print(f"整体预测准确率: {accuracy:.4f}")
# 按股票分析预测效果
stock_accuracy = self.data.groupby('stock_code').apply(
lambda x: (x['label'] == x['prediction']).mean()
)
print("\n股票预测准确率统计:")
print(stock_accuracy.describe())
self.print_memory_usage("预测后")
def run(self, output_results=True, sample_fraction=0.1):
"""
运行整个流程 - 使用内存优化技术
"""
try:
# 分步执行,每步完成后释放内存
self.load_and_preprocess_data()
gc.collect()
self.print_memory_usage("数据加载后")
self.create_features()
gc.collect()
self.print_memory_usage("特征创建后")
self.create_labels() # 使用新的标签创建方法
gc.collect()
self.print_memory_usage("标签创建后")
self.perform_clustering(n_clusters=self.config.get('n_clusters', 5))
gc.collect()
self.print_memory_usage("聚类后")
self.train_prediction_model(sample_fraction=sample_fraction)
gc.collect()
self.print_memory_usage("模型训练后")
self.predict_and_save(output_results)
gc.collect()
self.print_memory_usage("预测后")
print("训练和预测流程完成!")
except KeyboardInterrupt:
print("用户中断执行")
except Exception as e:
print(f"运行过程中出错: {str(e)}")
import traceback
traceback.print_exc()
# 配置参数
config = {
# 数据路径配置
'five_min_sz_path': r"D:\股票量化数据库\股票五分钟线csv数据\深证",
'five_min_sh_path': r"D:\股票量化数据库\股票五分钟线csv数据\上证",
'daily_sz_path': r"D:\股票量化数据库\股票csv数据\深证",
'daily_sh_path': r"D:\股票量化数据库\股票csv数据\上证",
# 输出路径
'output_path': r"D:\股票量化数据库\预测结果",
# 时间范围配置
'start_date': '2023-09-08',
'end_date': '2025-08-07',
# 聚类配置
'n_clusters': 5
}
# 创建并运行系统
if __name__ == "__main__":
# 打印环境信息
print(f"Python版本: {sys.version}")
print(f"Pandas版本: {pd.__version__}")
# 是否输出预测结果
output_results = True
# 抽样比例 (0.1 = 10%)
sample_fraction = 0.1
# 设置Pandas内存选项
pd.set_option('mode.chained_assignment', None)
pd.set_option('display.max_columns', None)
# 设置内存优化选项
pd.set_option('compute.use_numexpr', True)
pd.set_option('compute.use_bottleneck', True)
# 创建并运行系统
system = StockPredictionSystem(config)
system.run(output_results=output_results, sample_fraction=sample_fraction)
最新发布