修改下列代码,把"要求次日(T+1)收盘价(time15:00:00的close时候的股价)比次日(T+1)time9:35:00的close时候的股价大百分之五,且后天(T+2)time9:35:00的close时候的股价比次日(T+1)收盘价(time15:00:00的close时候的股价)大百分之一"的条件改成“要求次日(T+1)收盘价比次日(T+1)开盘价大百分之五,且后天(T+2)开盘价比次日(T+1)收盘价大百分之一”。每天的收盘价、开盘价在日线数据中,结合股票的五分钟数据线进行训练和预测。
下面是需要修改的代码:
# -*- coding: utf-8 -*-
"""
股票预测系统 - 带概率输出版本
"""
import os
import glob
import pandas as pd
import numpy as np
import joblib
import gc
import sys
import psutil
import warnings
from datetime import datetime
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import MiniBatchKMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import talib as ta
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
from tqdm import tqdm
from sklearn.model_selection import train_test_split
# 修复警告处理
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=pd.errors.ParserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class StockPredictionSystem:
def __init__(self, config):
"""
初始化预测系统 - 针对超大内存优化
:param config: 配置字典
"""
self.config = config
self.five_min_paths = {
'sz': config['five_min_sz_path'],
'sh': config['five_min_sh_path']
}
self.daily_paths = {
'sz': config['daily_sz_path'],
'sh': config['daily_sh_path']
}
self.output_path = config['output_path']
self.start_date = datetime.strptime(config['start_date'], '%Y-%m-%d')
self.end_date = datetime.strptime(config['end_date'], '%Y-%m-%d')
self.data = None
self.features = None
self.labels = None
self.scaler = StandardScaler()
self.cluster_model = None
self.prediction_model = None
self.feature_cols = []
self.temp_dir = os.path.join(self.output_path, "temp")
os.makedirs(self.temp_dir, exist_ok=True)
self.parquet_files = []
def print_memory_usage(self, step_name):
"""打印当前内存使用情况"""
process = psutil.Process(os.getpid())
mem = process.memory_info().rss / 1024 ** 2
print(f"[{step_name}] 当前内存使用: {mem:.2f} MB")
def safe_read_csv(self, file_path, required_columns):
"""
安全读取CSV文件 - 使用PyArrow进行高效读取
:param file_path: 文件路径
:param required_columns: 需要的列名列表
:return: 读取的DataFrame或None
"""
try:
# 检查文件大小
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return None
file_size = os.path.getsize(file_path)
if file_size == 0:
print(f"文件 {file_path} 大小为0,跳过")
return None
# 使用PyArrow读取CSV
read_options = pc.ReadOptions(
use_threads=True,
block_size=4096 * 1024 # 4MB块大小
)
parse_options = pc.ParseOptions(delimiter=',')
convert_options = pc.ConvertOptions(
include_columns=required_columns,
column_types={
'date': pa.string(),
'time': pa.string(),
'open': pa.float32(),
'high': pa.float32(),
'low': pa.float32(),
'close': pa.float32(),
'volume': pa.float32(),
'amount': pa.float32()
}
)
table = pc.read_csv(
file_path,
read_options=read_options,
parse_options=parse_options,
convert_options=convert_options
)
# 转换为Pandas DataFrame
df = table.to_pandas()
# 检查是否读取到数据
if df.empty:
print(f"文件 {file_path} 读取后为空")
return None
return df
except Exception as e:
print(f"读取文件 {file_path} 时出错: {str(e)}")
return None
def process_and_save_chunk(self, df, market, stock_code, chunk_index):
"""
处理单个股票的数据块并保存为Parquet文件 - 内存优化版本
"""
if df is None or df.empty:
return None
try:
# 添加市场前缀
df['stock_code'] = f"{market}_{stock_code}"
# 修复日期时间转换问题
df['date'] = df['date'].astype(str).str.zfill(8) # 填充为8位字符串
df['time'] = df['time'].astype(str)
# 处理时间格式
df['time'] = df['time'].apply(
lambda x: f"{x[:2]}:{x[2:4]}" if len(x) == 4 else x
)
# 合并日期和时间
df['datetime'] = pd.to_datetime(
df['date'] + ' ' + df['time'],
format='%Y%m%d %H:%M',
errors='coerce'
)
# 删除无效的日期时间
df = df.dropna(subset=['datetime'])
# 筛选日期范围
df = df[(df['datetime'] >= self.start_date) &
(df['datetime'] <= self.end_date)]
if df.empty:
return None
# 优化内存使用
df = df[['stock_code', 'datetime', 'open', 'high', 'low', 'close', 'volume', 'amount']]
# 保存为Parquet文件
output_file = os.path.join(self.temp_dir, f"{market}_{stock_code}_{chunk_index}.parquet")
# 使用PyArrow直接写入Parquet
table = pa.Table.from_pandas(df, preserve_index=False)
pq.write_table(table, output_file, compression='SNAPPY')
return output_file
except Exception as e:
print(f"处理股票 {stock_code} 时出错: {str(e)}")
return None
def incremental_merge_parquet_files(self, parquet_files, batch_size=100):
"""
增量合并Parquet文件 - 避免一次性加载所有数据
:param parquet_files: Parquet文件列表
:param batch_size: 每次合并的文件数量
:return: 合并后的Parquet文件路径
"""
merged_file = os.path.join(self.temp_dir, "merged_data.parquet")
# 如果文件已存在,删除
if os.path.exists(merged_file):
os.remove(merged_file)
# 分批合并文件
for i in tqdm(range(0, len(parquet_files), batch_size), desc="合并Parquet文件"):
batch_files = parquet_files[i:i+batch_size]
# 读取当前批次文件
tables = []
for file in batch_files:
try:
table = pq.read_table(file)
tables.append(table)
except Exception as e:
print(f"读取文件 {file} 出错: {str(e)}")
if not tables:
continue
# 合并当前批次
merged_table = pa.concat_tables(tables)
# 追加到输出文件
if os.path.exists(merged_file):
# 追加模式
with pq.ParquetWriter(merged_file, merged_table.schema) as writer:
writer.write_table(merged_table)
else:
# 首次写入
pq.write_table(merged_table, merged_file)
# 释放内存
del tables
del merged_table
gc.collect()
return merged_file
def load_and_preprocess_data(self):
"""
加载和预处理数据 - 使用增量合并避免内存溢出
"""
print("开始加载和预处理数据...")
self.print_memory_usage("开始加载数据")
# 创建临时目录
os.makedirs(self.temp_dir, exist_ok=True)
parquet_files = []
# 加载五分钟线数据
for market, path in self.five_min_paths.items():
print(f"开始处理市场: {market}, 路径: {path}")
file_count = 0
processed_count = 0
# 获取文件列表
csv_files = list(glob.glob(os.path.join(path, '*.csv')))
print(f"找到 {len(csv_files)} 个文件")
for file_path in tqdm(csv_files, desc=f"处理 {market} 市场文件"):
file_count += 1
stock_code = os.path.basename(file_path).split('.')[0]
try:
# 安全读取CSV文件
df = self.safe_read_csv(file_path, ['date', 'time', 'open', 'high', 'low', 'close', 'volume', 'amount'])
if df is None:
continue
# 处理并保存为Parquet
output_file = self.process_and_save_chunk(df, market, stock_code, processed_count)
if output_file:
parquet_files.append(output_file)
processed_count += 1
# 每处理100个文件释放内存
if processed_count % 100 == 0:
self.print_memory_usage(f"已处理 {processed_count} 个文件")
gc.collect()
except Exception as e:
print(f"处理文件 {file_path} 时出错: {str(e)}")
continue
print(f"市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件")
# 如果没有找到有效文件
if not parquet_files:
raise ValueError("没有找到有效的五分钟线数据")
print(f"开始增量合并 {len(parquet_files)} 个Parquet文件...")
self.print_memory_usage("合并前")
# 增量合并Parquet文件
merged_file = self.incremental_merge_parquet_files(parquet_files, batch_size=50)
# 加载合并后的数据
print(f"加载合并后的数据: {merged_file}")
self.data = pq.read_table(merged_file).to_pandas()
# 优化内存使用
self.data['stock_code'] = self.data['stock_code'].astype('category')
print(f"数据合并完成,共 {len(self.data)} 条记录")
self.print_memory_usage("合并后")
# 清理临时文件
for file in parquet_files:
try:
os.remove(file)
except:
pass
# 加载日线数据
daily_data = []
daily_required_columns = ['date', 'open', 'high', 'low', 'close', 'volume']
for market, path in self.daily_paths.items():
print(f"开始处理日线市场: {market}, 路径: {path}")
file_count = 0
processed_count = 0
# 获取所有CSV文件
all_files = list(glob.glob(os.path.join(path, '*.csv')))
print(f"找到 {len(all_files)} 个日线文件")
for file_path in tqdm(all_files, desc=f"处理 {market} 日线文件"):
file_count += 1
stock_code = os.path.basename(file_path).split('.')[0]
try:
# 安全读取CSV文件
df = self.safe_read_csv(file_path, daily_required_columns)
if df is None or df.empty:
continue
# 添加市场前缀
df['stock_code'] = f"{market}_{stock_code}"
# 转换日期格式
df['date'] = pd.to_datetime(df['date'], errors='coerce')
# 删除无效日期
df = df.dropna(subset=['date'])
# 筛选日期范围
df = df[(df['date'] >= self.start_date) &
(df['date'] <= self.end_date)]
if df.empty:
continue
# 优化内存使用
df = df[['stock_code', 'date', 'open', 'high', 'low', 'close', 'volume']]
# 优化数据类型
df['open'] = df['open'].astype(np.float32)
df['high'] = df['high'].astype(np.float32)
df['low'] = df['low'].astype(np.float32)
df['close'] = df['close'].astype(np.float32)
df['volume'] = df['volume'].astype(np.float32)
daily_data.append(df)
processed_count += 1
if processed_count % 100 == 0:
self.print_memory_usage(f"已处理 {processed_count} 个日线文件")
gc.collect()
except Exception as e:
print(f"处理日线文件 {file_path} 时出错: {str(e)}")
continue
print(f"日线市场 {market} 完成: 共 {file_count} 个文件, 成功处理 {processed_count} 个文件")
# 合并日线数据
if daily_data:
daily_df = pd.concat(daily_data, ignore_index=True)
daily_df['stock_code'] = daily_df['stock_code'].astype('category')
# 添加日线特征
self._add_daily_features(daily_df)
else:
print("警告: 没有找到日线数据")
print(f"数据加载完成,共 {len(self.data)} 条记录")
self.print_memory_usage("数据加载完成")
def _add_daily_features(self, daily_df):
"""
添加日线特征到五分钟线数据 - 使用内存优化技术
"""
print("添加日线特征...")
# 预处理日线数据
daily_df = daily_df.sort_values(['stock_code', 'date'])
# 计算日线技术指标
daily_df['daily_ma5'] = daily_df.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(5).mean())
daily_df['daily_ma10'] = daily_df.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(10).mean())
daily_df['daily_vol_ma5'] = daily_df.groupby('stock_code', observed=True)['volume'].transform(
lambda x: x.rolling(5).mean())
# 计算MACD - 使用更高效的方法
def calculate_macd(group):
group = group.sort_values('date')
if len(group) < 26:
return group.assign(daily_macd=np.nan, daily_signal=np.nan)
close_vals = group['close'].values.astype(np.float64)
macd, signal, _ = ta.MACD(close_vals, fastperiod=12, slowperiod=26, signalperiod=9)
return group.assign(daily_macd=macd, daily_signal=signal)
daily_df = daily_df.groupby('stock_code', group_keys=False, observed=True).apply(calculate_macd)
# 提取日期部分用于合并
self.data['date'] = self.data['datetime'].dt.date.astype('datetime64[ns]')
# 优化数据类型
daily_df = daily_df[['stock_code', 'date', 'daily_ma5', 'daily_ma10',
'daily_vol_ma5', 'daily_macd', 'daily_signal']]
daily_df['daily_ma5'] = daily_df['daily_ma5'].astype(np.float32)
daily_df['daily_ma10'] = daily_df['daily_ma10'].astype(np.float32)
daily_df['daily_vol_ma5'] = daily_df['daily_vol_ma5'].astype(np.float32)
daily_df['daily_macd'] = daily_df['daily_macd'].astype(np.float32)
daily_df['daily_signal'] = daily_df['daily_signal'].astype(np.float32)
# 合并日线特征
self.data = pd.merge(
self.data,
daily_df,
on=['stock_code', 'date'],
how='left'
)
# 删除临时列
del self.data['date']
# 释放内存
del daily_df
gc.collect()
def create_features(self):
"""
创建特征工程 - 使用内存优化技术
"""
print("开始创建特征...")
self.print_memory_usage("创建特征前")
if self.data is None:
raise ValueError("请先加载数据")
# 按股票和时间排序
self.data = self.data.sort_values(['stock_code', 'datetime'])
# 特征列表
features = []
# 1. 基础特征
features.append('open')
features.append('high')
features.append('low')
features.append('close')
features.append('volume')
features.append('amount')
# 2. 技术指标 - 使用分组计算避免内存溢出
# 计算移动平均线
self.data['ma5'] = self.data.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(5, min_periods=1).mean())
self.data['ma10'] = self.data.groupby('stock_code', observed=True)['close'].transform(
lambda x: x.rolling(10, min_periods=1).mean())
features.extend(['ma5', 'ma10'])
# 计算RSI - 使用更高效的方法
print("计算RSI指标...")
def calculate_rsi(group):
group = group.sort_values('datetime')
close = group['close'].values.astype(np.float64)
rsi = ta.RSI(close, timeperiod=14)
return group.assign(rsi=rsi)
self.data = self.data.groupby('stock_code', group_keys=False, observed=True).apply(calculate_rsi)
features.append('rsi')
# 3. 波动率特征
print("计算波动率特征...")
self.data['price_change'] = self.data.groupby('stock_code', observed=True)['close'].pct_change()
self.data['volatility'] = self.data.groupby('stock_code', observed=True)['price_change'].transform(
lambda x: x.rolling(10, min_periods=1).std())
features.append('volatility')
# 4. 成交量特征
self.data['vol_change'] = self.data.groupby('stock_code', observed=True)['volume'].pct_change()
self.data['vol_ma5'] = self.data.groupby('stock_code', observed=True)['volume'].transform(
lambda x: x.rolling(5, min_periods=1).mean())
features.extend(['vol_change', 'vol_ma5'])
# 5. 日线特征
features.extend(['daily_ma5', 'daily_ma10', 'daily_vol_ma5',
'daily_macd', 'daily_signal'])
# 保存特征列
self.feature_cols = features
# 处理缺失值 - 只删除特征列中的缺失值
self.data = self.data.dropna(subset=features)
# 优化数据类型
for col in features:
if self.data[col].dtype == np.float64:
self.data[col] = self.data[col].astype(np.float32)
print(f"特征创建完成,共 {len(features)} 个特征")
self.print_memory_usage("创建特征后")
def clean_data(self):
"""
清洗数据 - 处理无穷大和超出范围的值
"""
print("开始数据清洗...")
self.print_memory_usage("清洗前")
# 1. 检查无穷大值
inf_mask = np.isinf(self.data[self.feature_cols].values)
inf_rows = np.any(inf_mask, axis=1)
inf_count = np.sum(inf_rows)
if inf_count > 0:
print(f"发现 {inf_count} 行包含无穷大值,正在清理...")
# 将无穷大替换为NaN
self.data[self.feature_cols] = self.data[self.feature_cols].replace([np.inf, -np.inf], np.nan)
# 2. 检查超出float32范围的值
float32_max = np.finfo(np.float32).max
float32_min = np.finfo(np.float32).min
# 统计超出范围的值
overflow_count = 0
for col in self.feature_cols:
col_max = self.data[col].max()
col_min = self.data[col].min()
if col_max > float32_max or col_min < float32_min:
overflow_count += 1
print(f"列 {col} 包含超出float32范围的值: min={col_min}, max={col_max}")
if overflow_count > 0:
print(f"共发现 {overflow_count} 列包含超出float32范围的值,正在处理...")
# 缩放到安全范围
for col in self.feature_cols:
col_min = self.data[col].min()
col_max = self.data[col].max()
# 如果范围过大,进行缩放
if col_max - col_min > 1e6:
print(f"列 {col} 范围过大 ({col_min} 到 {col_max}),进行缩放...")
self.data[col] = (self.data[col] - col_min) / (col_max - col_min)
# 3. 处理NaN值
nan_count = self.data[self.feature_cols].isna().sum().sum()
if nan_count > 0:
print(f"发现 {nan_count} 个NaN值,使用前向填充处理...")
# 使用transform保持索引一致
for col in self.feature_cols:
self.data[col] = self.data.groupby('stock_code', observed=True)[col].transform(
lambda x: x.fillna(method='ffill').fillna(method='bfill').fillna(0)
)
# 4. 最终检查
cleaned = True
for col in self.feature_cols:
if np.isinf(self.data[col]).any() or self.data[col].isna().any():
print(f"警告: 列 {col} 仍包含无效值")
cleaned = False
if cleaned:
print("数据清洗完成")
else:
print("数据清洗完成,但仍存在部分问题")
self.print_memory_usage("清洗后")
def create_labels(self):
"""
创建标签 - 优化标签创建逻辑,确保有正负样本
"""
print("开始创建标签...")
self.print_memory_usage("创建标签前")
if self.data is None:
raise ValueError("请先加载数据")
# 按股票和时间排序
self.data = self.data.sort_values(['stock_code', 'datetime'])
# 添加日期列用于合并
self.data['date'] = self.data['datetime'].dt.date
# 创建每日关键时间点价格数据
daily_key_points = self.data.groupby(['stock_code', 'date']).apply(
lambda x: pd.Series({
'time9_35_close': x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()]['close'].iloc[0]
if not x[x['datetime'].dt.time == pd.to_datetime('09:35:00').time()].empty else np.nan,
'time15_00_close': x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()]['close'].iloc[0]
if not x[x['datetime'].dt.time == pd.to_datetime('15:00:00').time()].empty else np.nan
})
).reset_index()
# 为每日关键点添加次日(T+1)和后日(T+2)数据
daily_key_points = daily_key_points.sort_values(['stock_code', 'date'])
daily_key_points['next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-1)
daily_key_points['next_next_date'] = daily_key_points.groupby('stock_code')['date'].shift(-2)
# 合并次日(T+1)数据
daily_key_points = pd.merge(
daily_key_points,
daily_key_points[['stock_code', 'date', 'time9_35_close', 'time15_00_close']].rename(
columns={
'date': 'next_date',
'time9_35_close': 'next_time9_35_close',
'time15_00_close': 'next_time15_00_close'
}
),
on=['stock_code', 'next_date'],
how='left'
)
# 合并后日(T+2)数据
daily_key_points = pd.merge(
daily_key_points,
daily_key_points[['stock_code', 'date', 'time9_35_close']].rename(
columns={
'date': 'next_next_date',
'time9_35_close': 'next_next_time9_35_close'
}
),
on=['stock_code', 'next_next_date'],
how='left'
)
# 将关键点数据合并回原始数据
self.data = pd.merge(
self.data,
daily_key_points[['stock_code', 'date', 'next_time9_35_close', 'next_time15_00_close', 'next_next_time9_35_close']],
on=['stock_code', 'date'],
how='left'
)
# 计算新条件 - 添加容错处理
# 条件1: 次日(T+1)收盘价比次日9:35收盘价大5%
cond1 = (self.data['next_time15_00_close'] > self.data['next_time9_35_close'] * 1.05)
# 条件2: 后日(T+2)9:35收盘价比次日收盘价大1%
cond2 = (self.data['next_next_time9_35_close'] > self.data['next_time15_00_close'] * 1.01)
# 处理可能的NaN值
cond1 = cond1.fillna(False)
cond2 = cond2.fillna(False)
# 创建标签(满足两个条件则为1)
self.data['label'] = np.where(cond1 & cond2, 1, 0).astype(np.int8)
# 删除中间列
self.data.drop([
'date',
'next_time9_35_close',
'next_time15_00_close',
'next_next_time9_35_close'
], axis=1, inplace=True, errors='ignore')
# 保存标签
self.labels = self.data['label']
# 分析标签分布
label_counts = self.data['label'].value_counts(normalize=True)
print(f"标签分布:\n{label_counts}")
# 检查标签多样性
unique_labels = self.data['label'].nunique()
if unique_labels < 2:
print(f"警告: 只有 {unique_labels} 个标签类别,可能导致模型训练失败")
# 尝试创建替代标签
self._create_alternative_labels()
print("标签创建完成")
self.print_memory_usage("创建标签后")
def _create_alternative_labels(self):
"""创建替代标签确保有正负样本"""
print("创建替代标签确保多样性...")
# 方法1: 基于价格变化创建标签
self.data['price_change'] = self.data.groupby('stock_code')['close'].pct_change()
self.data['label'] = np.where(self.data['price_change'] > 0.01, 1, 0)
# 更新标签
self.labels = self.data['label']
# 再次分析标签分布
label_counts = self.data['label'].value_counts(normalize=True)
print(f"替代标签分布:\n{label_counts}")
def perform_clustering(self, n_clusters=5, batch_size=100000):
"""
执行聚类分析 - 使用MiniBatchKMeans处理大数据
:param n_clusters: 聚类数量
:param batch_size: 每次处理的样本数量
"""
print(f"开始聚类分析,聚类数: {n_clusters}...")
self.print_memory_usage("聚类前")
if self.feature_cols is None:
raise ValueError("请先创建特征")
# 添加数据清洗步骤
self.clean_data()
# 标准化特征
print("标准化特征...")
self.scaler.fit(self.data[self.feature_cols])
# 使用MiniBatchKMeans进行聚类
self.cluster_model = MiniBatchKMeans(
n_clusters=n_clusters,
batch_size=batch_size,
random_state=42,
n_init=3
)
# 分批处理数据
print("分批聚类...")
n_samples = len(self.data)
for i in tqdm(range(0, n_samples, batch_size), desc="聚类进度"):
batch_data = self.data.iloc[i:i+batch_size]
scaled_batch = self.scaler.transform(batch_data[self.feature_cols])
self.cluster_model.partial_fit(scaled_batch)
# 获取最终聚类结果
print("获取聚类结果...")
clusters = []
for i in tqdm(range(0, n_samples, batch_size), desc="分配聚类"):
batch_data = self.data.iloc[i:i+batch_size]
scaled_batch = self.scaler.transform(batch_data[self.feature_cols])
batch_clusters = self.cluster_model.predict(scaled_batch)
clusters.append(batch_clusters)
# 添加聚类结果到数据
self.data['cluster'] = np.concatenate(clusters)
self.feature_cols.append('cluster')
# 分析聚类结果
cluster_summary = self.data.groupby('cluster')['label'].agg(['mean', 'count'])
print("聚类结果分析:")
print(cluster_summary)
# 保存聚类模型
cluster_model_path = os.path.join(
self.output_path,
"分钟线预测训练聚类模型.pkl"
)
joblib.dump(self.cluster_model, cluster_model_path)
print(f"聚类模型已保存至: {cluster_model_path}")
self.print_memory_usage("聚类后")
def train_prediction_model(self):
"""
训练预测模型 - 使用所有数据进行训练
"""
print("开始训练预测模型(使用所有数据)...")
self.print_memory_usage("训练模型前")
if self.feature_cols is None or self.labels is None:
raise ValueError("请先创建特征和标签")
# 使用所有数据
X = self.data[self.feature_cols]
y = self.labels
# 检查类别分布
if y.nunique() < 2:
print("警告: 只有一个类别的数据,无法训练模型")
return
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# 训练随机森林分类器
self.prediction_model = RandomForestClassifier(
n_estimators=100,
max_depth=8,
min_samples_split=10,
class_weight='balanced',
random_state=42,
n_jobs=-1
)
self.prediction_model.fit(X_train, y_train)
# 评估模型
y_pred = self.prediction_model.predict(X_test)
print("模型评估报告:")
print(classification_report(y_test, y_pred))
# 打印混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print("混淆矩阵:")
print(cm)
# 保存预测模型
model_path = os.path.join(
self.output_path,
"分钟线预测训练模型.pkl"
)
joblib.dump(self.prediction_model, model_path)
print(f"预测模型已保存至: {model_path}")
self.print_memory_usage("训练模型后")
def predict_and_save(self, output_results=True):
"""
使用模型进行预测并保存结果(添加满足条件的概率)
:param output_results: 是否输出预测结果
"""
print("开始预测(添加概率)...")
self.print_memory_usage("预测前")
if self.prediction_model is None:
raise ValueError("请先训练预测模型")
# 准备预测数据
X = self.data[self.feature_cols]
# 分批预测概率
probabilities = [] # 存储满足条件的概率
predictions = [] # 存储预测类别
batch_size = 10000
n_samples = len(X)
for i in tqdm(range(0, n_samples, batch_size), desc="预测概率进度"):
batch_data = X.iloc[i:i+batch_size]
# 预测类别和概率
batch_pred = self.prediction_model.predict(batch_data)
batch_proba = self.prediction_model.predict_proba(batch_data)
# 获取满足条件(label=1)的概率
# 注意:predict_proba返回的是每个类别的概率,索引0对应label=0,索引1对应label=1
if batch_proba.shape[1] > 1: # 确保是多分类问题
batch_prob1 = batch_proba[:, 1] # 获取label=1的概率
else:
batch_prob1 = np.ones(len(batch_data)) # 如果只有一类,设为1
probabilities.append(batch_prob1)
predictions.append(batch_pred)
# 合并预测结果
self.data['prediction'] = np.concatenate(predictions)
self.data['probability'] = np.concatenate(probabilities)
# 保存预测结果
if output_results:
output_file = os.path.join(self.output_path, "预测结果.csv")
# 只保存关键列以减小文件大小
result_cols = ['stock_code', 'datetime', 'close', 'label', 'prediction', 'probability']
# 添加额外的时间信息便于分析
self.data['date'] = self.data['datetime'].dt.date
self.data['time'] = self.data['datetime'].dt.time
result_df = self.data[['date', 'time'] + result_cols].copy()
# 优化输出格式
result_df['probability'] = result_df['probability'].round(4) # 保留4位小数
# 添加交易信号列
result_df['signal'] = np.where(
(result_df['prediction'] == 1) & (result_df['probability'] > 0.7),
'强烈买入',
np.where(
(result_df['prediction'] == 1) & (result_df['probability'] > 0.5),
'买入',
np.where(
(result_df['prediction'] == 0) & (result_df['probability'] < 0.3),
'卖出',
'观望'
)
)
)
result_df.to_csv(output_file, index=False)
print(f"预测结果已保存至: {output_file}")
# 分析预测效果
accuracy = (self.data['label'] == self.data['prediction']).mean()
print(f"整体预测准确率: {accuracy:.4f}")
# 分析概率分布
print("\n满足条件的概率分布统计:")
print(self.data['probability'].describe())
# 按股票分析预测效果
stock_accuracy = self.data.groupby('stock_code').apply(
lambda x: (x['label'] == x['prediction']).mean()
)
print("\n股票预测准确率统计:")
print(stock_accuracy.describe())
# 分析不同概率区间的准确率
print("\n不同概率区间的预测准确率:")
bins = [0, 0.3, 0.5, 0.7, 1.0]
labels = ['低概率(0-0.3)', '中低概率(0.3-0.5)', '中高概率(0.5-0.7)', '高概率(0.7-1.0)']
self.data['prob_bin'] = pd.cut(self.data['probability'], bins=bins, labels=labels)
bin_accuracy = self.data.groupby('prob_bin').apply(
lambda x: (x['label'] == x['prediction']).mean()
)
print(bin_accuracy)
self.print_memory_usage("预测后")
def run(self, output_results=True):
"""
运行整个流程 - 添加错误处理和内存优化
"""
try:
# 分步执行,每步完成后释放内存
self.load_and_preprocess_data()
gc.collect()
self.print_memory_usage("数据加载后")
self.create_features()
gc.collect()
self.print_memory_usage("特征创建后")
self.create_labels() # 使用新的标签创建方法
gc.collect()
self.print_memory_usage("标签创建后")
# 检查标签多样性
if self.labels.nunique() < 2:
print("错误: 标签不足两个类别,跳过聚类和模型训练")
return
self.perform_clustering(n_clusters=self.config.get('n_clusters', 5))
gc.collect()
self.print_memory_usage("聚类后")
self.train_prediction_model() # 使用修改后的训练方法
gc.collect()
self.print_memory_usage("模型训练后")
# 检查模型是否训练成功
if self.prediction_model is not None:
self.predict_and_save(output_results)
gc.collect()
self.print_memory_usage("预测后")
print("训练和预测流程完成!")
else:
print("模型训练失败,跳过预测步骤")
except KeyboardInterrupt:
print("用户中断执行")
except Exception as e:
print(f"运行过程中出错: {str(e)}")
import traceback
traceback.print_exc()
# 配置参数
config = {
# 数据路径配置
'five_min_sz_path': r"D:\股票量化数据库\股票五分钟线csv数据\深证",
'five_min_sh_path': r"D:\股票量化数据库\股票五分钟线csv数据\上证",
'daily_sz_path': r"D:\股票量化数据库\股票csv数据\深证",
'daily_sh_path': r"D:\股票量化数据库\股票csv数据\上证",
# 输出路径
'output_path': r"D:\股票量化数据库\预测结果",
# 时间范围配置
'start_date': '2023-09-08',
'end_date': '2025-08-08',
# 聚类配置
'n_clusters': 5
}
# 创建并运行系统
if __name__ == "__main__":
# 打印环境信息
print(f"Python版本: {sys.version}")
print(f"Pandas版本: {pd.__version__}")
# 是否输出预测结果
output_results = True
# 设置Pandas内存选项
pd.set_option('mode.chained_assignment', None)
pd.set_option('display.max_columns', None)
# 设置内存优化选项
pd.set_option('compute.use_numexpr', True)
pd.set_option('compute.use_bottleneck', True)
# 创建并运行系统
system = StockPredictionSystem(config)
system.run(output_results=output_results)