# Paste residue (not code): "基于下列代码进行上述要求的修改" - "modify the code below according to the requirements above".
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 20 22:17:32 2025
@author: srx20
"""
# -*- coding: utf-8 -*-
"""
股票预测筛选程序 - 支持自定义排名范围输出并屏蔽特定股票
"""
import os
import joblib
import pandas as pd
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Tuple
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import MiniBatchKMeans
import talib as ta
import logging
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib import font_manager as fm
import base64
from io import BytesIO
# Matplotlib setup: SimHei as the sans-serif font so Chinese labels render,
# and unicode_minus disabled so the minus sign is displayed correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Logging setup: INFO and above go to a log file and to the console.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler('stock_prediction_filter.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# ========== 配置类 ==========
class StockConfig:
    """Settings shared by the feature, clustering and prediction steps.

    Every instance gets its own list objects, so mutating one configuration
    never affects another.
    """

    def __init__(self):
        # Directories holding the per-stock CSV files (Shanghai / Shenzhen).
        self.SH_PATH = r"D:\股票量化数据库\股票csv数据\上证"
        self.SZ_PATH = r"D:\股票量化数据库\股票csv数据\深证"

        # Clustering: number of clusters and the features used for it.
        self.CLUSTER_NUM = 8
        self.CLUSTER_FEATURES = [
            'price_change',
            'volatility',
            'volume_change',
            'MA5',
            'MA20',
            'RSI14',
            'MACD_hist',
        ]

        # Target thresholds (minimum gain and minimum low/open ratio).
        self.MIN_GAIN = 0.05
        self.MIN_LOW_RATIO = 0.98

        # Prediction features (initial list): raw OHLCV columns, then the
        # clustering features, then the remaining derived columns.
        raw_columns = ['open', 'high', 'low', 'close', 'volume']
        derived_columns = [
            'cluster', 'MOM10', 'ATR14', 'VWAP', 'RSI_diff',
            'price_vol_ratio', 'MACD_RSI', 'advance_decline',
            'day_of_week', 'month',
        ]
        self.PREDICT_FEATURES = raw_columns + self.CLUSTER_FEATURES + derived_columns

        # Stock-code prefixes that are excluded from prediction.
        self.BLOCKED_PREFIXES = [
            'SZ_sz16',
            'SH_sh88',
            'SH_sh999',
            'SH_sh688',
            'SZ_sz300',
            'SH_sh500',
        ]
# ========== 特征工程 ==========
class FeatureEngineer:
    """Derive technical-indicator features from one stock's OHLCV frame.

    ``transform`` reads the ``date``, ``open``, ``high``, ``low``, ``close``
    and ``volume`` columns. The caller in this file passes rows sorted by
    ascending date, which the rolling and indicator calculations rely on.
    """

    def __init__(self, config):
        # ``config.PREDICT_FEATURES`` is read on the fallback path only.
        self.config = config

    def safe_fillna(self, series, default=0):
        """Return ``series`` with NaN replaced by ``default``.

        pandas Series and NumPy arrays are handled; any other object is
        returned unchanged.
        """
        if isinstance(series, pd.Series):
            return series.fillna(default)
        if isinstance(series, np.ndarray):
            return np.nan_to_num(series, nan=default)
        return series

    @staticmethod
    def _add_basic_features(df):
        """Add the pandas-only features (returns, volatility, moving averages) in place."""
        close = df['close']
        df['price_change'] = close.pct_change().fillna(0)
        df['volatility'] = close.rolling(5).std().fillna(0)
        df['volume_change'] = df['volume'].pct_change().fillna(0)
        df['MA5'] = close.rolling(5).mean().fillna(0)
        df['MA20'] = close.rolling(20).mean().fillna(0)

    @staticmethod
    def _sanitize(df):
        """Return a copy of ``df`` with +/-inf and NaN replaced by 0."""
        return df.replace([np.inf, -np.inf], np.nan).fillna(0)

    def transform(self, df):
        """Add technical-indicator columns to ``df`` and return the cleaned frame.

        ``df`` is modified in place while features are added; the returned
        frame is a sanitized copy in which +/-inf and NaN are replaced by 0.

        If any step fails, the error is logged and a reduced feature set is
        returned instead: the basic pandas features plus a zero column for
        every name in ``config.PREDICT_FEATURES`` that is still missing. The
        fallback result is sanitized the same way as the normal result, so an
        infinite ``volume_change`` (previous volume of 0) never reaches the
        caller.
        """
        try:
            # TA-Lib requires float64 input arrays, so the price columns are
            # converted once up front instead of copying the whole frame.
            close = df['close'].astype(np.float64).to_numpy()
            high = df['high'].astype(np.float64).to_numpy()
            low = df['low'].astype(np.float64).to_numpy()

            # Basic features.
            self._add_basic_features(df)

            # Technical indicators. RSI warm-up rows default to the neutral 50.
            df['RSI14'] = self.safe_fillna(ta.RSI(close, timeperiod=14), 50)
            _macd, _macd_signal, macd_hist = ta.MACD(
                close,
                fastperiod=12,
                slowperiod=26,
                signalperiod=9
            )
            df['MACD_hist'] = self.safe_fillna(macd_hist, 0)
            df['MOM10'] = self.safe_fillna(ta.MOM(close, timeperiod=10), 0)
            df['ATR14'] = self.safe_fillna(ta.ATR(high, low, close, timeperiod=14), 0)

            # Volume-weighted average price, cumulative from the first row.
            typical_price = (df['high'] + df['low'] + df['close']) / 3
            vwap = (df['volume'] * typical_price).cumsum() / df['volume'].cumsum()
            df['VWAP'] = self.safe_fillna(vwap, 0)

            # RSI minus its own 5-row rolling mean.
            df['RSI_diff'] = df['RSI14'] - df['RSI14'].rolling(5).mean().fillna(0)
            # Return relative to volatility; zero volatility is replaced by a
            # tiny epsilon to avoid dividing by zero.
            df['price_vol_ratio'] = df['price_change'] / (df['volatility'].replace(0, 1e-8) + 1e-8)
            # Combined indicator feature.
            df['MACD_RSI'] = df['MACD_hist'] * df['RSI14']
            # Number of up-closes (close > open) in the last 5 rows.
            df['advance_decline'] = (df['close'] > df['open']).astype(int).rolling(5).sum().fillna(0)
            # Calendar features.
            df['day_of_week'] = df['date'].dt.dayofweek
            df['month'] = df['date'].dt.month

            return self._sanitize(df)
        except Exception as e:
            logger.error(f"特征工程失败: {str(e)}", exc_info=True)
            # Fallback: keep whatever was computed, make sure the basic
            # features exist and zero-fill every expected column that is missing.
            self._add_basic_features(df)
            for col in self.config.PREDICT_FEATURES:
                if col not in df.columns:
                    df[col] = 0
            # The fallback must be sanitized too: pct_change() can yield inf.
            return self._sanitize(df)
# ========== 聚类模型 ==========
class StockCluster:
    """Holds a persisted clustering model and maps stock codes to cluster ids."""

    def __init__(self, config):
        self.config = config
        self.scaler = StandardScaler()
        self.kmeans = MiniBatchKMeans(
            n_clusters=config.CLUSTER_NUM,
            random_state=42,
            batch_size=1000
        )
        self.cluster_map = {}  # stock code -> cluster id
        self.model_file = "stock_cluster_model.pkl"  # where the model is persisted

    def load(self):
        """Load the clustering model from ``self.model_file``.

        Returns:
            True when the file exists and contains the ``kmeans``, ``scaler``
            and ``cluster_map`` entries; False when the file is missing or
            cannot be loaded. On failure the current attributes are left
            untouched, so the object is never half-updated.
        """
        if not os.path.exists(self.model_file):
            logger.warning("聚类模型文件不存在")
            return False
        try:
            # NOTE(security): joblib.load unpickles the file, which can run
            # arbitrary code. Only load model files from a trusted source.
            model_data = joblib.load(self.model_file)
            kmeans = model_data['kmeans']
            scaler = model_data['scaler']
            cluster_map = model_data['cluster_map']
        except Exception as e:
            # Same policy as the prediction-model load in this file: a broken
            # model file is logged and reported, not allowed to crash the run.
            logger.error(f"加载聚类模型失败: {str(e)}", exc_info=True)
            return False
        self.kmeans = kmeans
        self.scaler = scaler
        self.cluster_map = cluster_map
        logger.info(f"从 {self.model_file} 加载聚类模型")
        return True

    def transform(self, df, stock_code):
        """Add a constant ``cluster`` column to ``df`` (in place) and return it.

        Stock codes that are not in ``cluster_map`` get -1 (unknown cluster).
        """
        df['cluster'] = self.cluster_map.get(stock_code, -1)
        return df
# ========== 数据加载函数 ==========
def load_prediction_data(sh_path: str, sz_path: str, lookback_days: int = 30) -> Dict[str, pd.DataFrame]:
    """Load the most recent rows of every stock CSV found in two directories.

    Args:
        sh_path: Directory with the Shanghai CSV files (codes get the ``SH_`` prefix).
        sz_path: Directory with the Shenzhen CSV files (codes get the ``SZ_`` prefix).
        lookback_days: Maximum number of most recent rows kept per stock.

    Returns:
        A dict mapping ``"<exchange>_<file stem>"`` to a DataFrame sorted by
        descending date. Prices are float32 and volume is int64. Files that
        lack a required column, fail to parse, or have no valid row are
        skipped (and logged). Missing directories are ignored.
    """
    required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
    price_cols = ['open', 'high', 'low', 'close']
    # Final dtypes, applied only after rows with missing values are removed.
    final_dtypes = {col: np.float32 for col in price_cols}
    # int64 rather than uint32: large volumes (and their cumulative sums
    # computed downstream) would silently wrap around in 32 bits.
    final_dtypes['volume'] = np.int64

    stock_data: Dict[str, pd.DataFrame] = {}

    # List each directory once; the same list drives the progress-bar total
    # and the loading loop.
    csv_files = []
    for exchange, path in (('SH', sh_path), ('SZ', sz_path)):
        if not os.path.isdir(path):
            continue
        csv_files.extend(
            (exchange, path, name) for name in os.listdir(path) if name.endswith('.csv')
        )
    if not csv_files:
        logger.warning("没有找到任何CSV文件")
        return stock_data

    with tqdm(total=len(csv_files), desc='加载股票数据') as pbar:
        for exchange, path, name in csv_files:
            stock_code = f"{exchange}_{name.split('.')[0]}"
            try:
                df = pd.read_csv(os.path.join(path, name))
                if not all(col in df.columns for col in required_cols):
                    logger.debug(f"股票 {stock_code} 缺少必要列,跳过")
                    continue
                # Newest rows first, then keep only the lookback window.
                df['date'] = pd.to_datetime(df['date'])
                df = df.sort_values('date', ascending=False).head(lookback_days).copy()
                # Coerce to numeric first (bad cells become NaN), drop the
                # incomplete rows, and only then cast to the compact dtypes.
                # Casting NaN to an integer dtype would raise and discard the
                # whole stock because of a single bad cell.
                for col in price_cols + ['volume']:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                df = df.dropna(subset=required_cols).astype(final_dtypes)
                if len(df) > 0:
                    stock_data[stock_code] = df
                    logger.debug(f"成功加载股票 {stock_code},数据条数: {len(df)}")
                else:
                    logger.warning(f"股票 {stock_code} 无有效数据")
            except Exception as e:
                # One unreadable file must not abort the whole load.
                logger.error(f"加载股票 {stock_code} 失败: {str(e)}", exc_info=True)
            finally:
                pbar.update(1)
    logger.info(f"成功加载 {len(stock_data)} 只股票数据")
    return stock_data
# ========== 生成HTML报告 ==========
def _star_rating(prob: float) -> str:
    """Map a probability to a one-to-five star rating string."""
    if prob >= 0.95:
        return "⭐⭐⭐⭐⭐"
    if prob >= 0.9:
        return "⭐⭐⭐⭐"
    if prob >= 0.85:
        return "⭐⭐⭐"
    if prob >= 0.8:
        return "⭐⭐"
    return "⭐"


def _render_probability_chart(top_stocks: List[Tuple[str, float]],
                              start_rank: int,
                              end_rank: int) -> str:
    """Draw the probability bar chart and return it as a Base64-encoded PNG."""
    codes = [code for code, _ in top_stocks]
    probs = [prob for _, prob in top_stocks]
    # Keep the 0.7-1.0 window when every bar fits into it; otherwise lower the
    # floor so that probabilities below 0.7 are still visible.
    y_floor = 0.7
    if probs and min(probs) < y_floor:
        y_floor = max(0.0, min(probs) - 0.05)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.bar(codes, probs, color='skyblue')
        plt.title(f'Top {start_rank}-{end_rank}股票上涨概率分布', fontsize=16)
        plt.xlabel('股票代码', fontsize=12)
        plt.ylabel('上涨概率', fontsize=12)
        plt.xticks(rotation=90, fontsize=8)
        plt.ylim(y_floor, 1.0)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def generate_html_report(top_stocks: List[Tuple[str, float]],
                         prediction_date: str,
                         model_version: str = "1.0",
                         start_rank: int = 1,
                         end_rank: int = 50,
                         blocked_count: int = 0) -> str:
    """Build the HTML prediction report.

    Args:
        top_stocks: List of ``(stock code, probability)`` tuples, best first.
        prediction_date: Prediction date shown in the report.
        model_version: Model version shown in the report.
        start_rank: Rank assigned to the first entry of ``top_stocks``.
        end_rank: Last rank of the requested range (used in titles only).
        blocked_count: Number of stocks that were filtered out.

    Returns:
        The complete HTML document as a string. The chart is embedded as a
        Base64 PNG. The blocked-prefix list shown in the report is read from
        ``StockConfig`` so that it matches the prefixes that are filtered.
    """
    from html import escape  # local import: only this function needs it

    chart_base64 = _render_probability_chart(top_stocks, start_rank, end_rank)
    blocked_prefix_text = ", ".join(escape(str(p)) for p in StockConfig().BLOCKED_PREFIXES)
    safe_date = escape(str(prediction_date))
    safe_version = escape(str(model_version))

    # Document head, info box, chart and table header.
    head = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>大涨小跌预测结果 ({start_rank}-{end_rank}名)</title>
<style>
body {{
font-family: 'Microsoft YaHei', sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f7fa;
color: #333;
}}
.container {{
max-width: 1200px;
margin: 0 auto;
background-color: white;
border-radius: 10px;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
padding: 30px;
}}
.header {{
text-align: center;
padding-bottom: 20px;
border-bottom: 1px solid #eee;
margin-bottom: 30px;
}}
.header h1 {{
color: #1e3a8a;
margin-bottom: 10px;
}}
.header .subtitle {{
color: #6b7280;
font-size: 18px;
}}
.info-box {{
background-color: #f0f7ff;
border-left: 4px solid #3b82f6;
padding: 15px;
margin-bottom: 30px;
border-radius: 0 5px 5px 0;
}}
.chart-container {{
text-align: center;
margin-bottom: 30px;
}}
.chart-container img {{
max-width: 100%;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
table {{
width: 100%;
border-collapse: collapse;
margin-bottom: 30px;
}}
th, td {{
padding: 12px 15px;
text-align: center;
border-bottom: 1px solid #e5e7eb;
}}
th {{
background-color: #3b82f6;
color: white;
font-weight: bold;
}}
tr:nth-child(even) {{
background-color: #f9fafb;
}}
tr:hover {{
background-color: #f0f7ff;
}}
.footer {{
text-align: center;
padding-top: 20px;
border-top: 1px solid #eee;
color: #6b7280;
font-size: 14px;
}}
.highlight {{
color: #10b981;
font-weight: bold;
}}
.rank-1 {{ background-color: #ffeb3b; }}
.rank-2 {{ background-color: #e0e0e0; }}
.rank-3 {{ background-color: #ff9800; }}
.blocked-info {{
background-color: #fff3cd;
border-left: 4px solid #ffc107;
padding: 10px;
margin-top: 15px;
border-radius: 0 5px 5px 0;
}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>大涨小跌预测结果 ({start_rank}-{end_rank}名)</h1>
<div class="subtitle">基于机器学习模型的股票预测分析</div>
</div>
<div class="info-box">
<p><strong>预测日期:</strong>{safe_date}</p>
<p><strong>模型版本:</strong>{safe_version}</p>
<p><strong>排名范围:</strong>{start_rank}-{end_rank}名</p>
<p><strong>筛选条件:</strong>收盘价 > 开盘价 × 105% 且 最低价 > 开盘价 × 98%</p>
<p><strong>屏蔽规则:</strong>已过滤掉特定前缀的股票({blocked_prefix_text})</p>
<p><strong>说明:</strong>本报告基于历史数据预测,不构成投资建议</p>
</div>
<div class="chart-container">
<h2>Top {start_rank}-{end_rank}股票上涨概率分布图</h2>
<img src="data:image/png;base64,{chart_base64}" alt="股票上涨概率分布图">
</div>
<h2>详细预测结果</h2>
<table>
<thead>
<tr>
<th>排名</th>
<th>股票代码</th>
<th>上涨概率</th>
<th>预测评级</th>
</tr>
</thead>
<tbody>
"""

    # One table row per stock; only absolute ranks 1-3 get a highlight class.
    podium_classes = {1: "class='rank-1'", 2: "class='rank-2'", 3: "class='rank-3'"}
    rows = []
    for rank, (stock_code, prob) in enumerate(top_stocks, start=start_rank):
        row_class = podium_classes.get(rank, "")
        rows.append(f"""
<tr {row_class}>
<td>{rank}</td>
<td>{escape(str(stock_code))}</td>
<td class="highlight">{prob:.4f}</td>
<td>{_star_rating(prob)}</td>
</tr>
""")

    # Table end, blocked-stock note and footer.
    tail = f"""
</tbody>
</table>
<div class="blocked-info">
<p><strong>屏蔽信息:</strong>在预测过程中已过滤掉 {blocked_count} 只以特定前缀开头的股票</p>
</div>
<div class="footer">
<p>生成时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | 预测模型:LightGBM分类器</p>
<p>© 2025 股票量化分析系统 | 本报告仅供研究参考</p>
</div>
</div>
</body>
</html>
"""
    return "".join([head, *rows, tail])
# ========== 主预测函数 ==========
def predict_top_stocks(model_path: str = "stock_prediction_model.pkl",
                       top_n: int = 50,
                       start_rank: int = 1,
                       end_rank: int = 50,
                       lookback_days: int = 120) -> List[Tuple[str, float]]:
    """Score every loaded stock, keep a rank window and write an HTML report.

    Args:
        model_path: Path of the joblib file holding ``(model, selected_features)``.
        top_n: Size of the ranking from which the window is taken.
        start_rank: First rank of the window (inclusive, 1-based).
        end_rank: Last rank of the window (inclusive).
        lookback_days: Number of most recent rows loaded per stock for feature
            computation. It must comfortably exceed the indicator warm-up:
            MACD(12, 26, 9) alone needs 33 rows before it yields a value, so
            the former fixed window of 30 rows left ``MACD_hist`` and
            ``MACD_RSI`` at 0 for every stock.

    Returns:
        The ``(stock code, probability)`` tuples of the selected window,
        highest probability first. An empty list is returned when the model
        file is missing or unreadable, or when no stock data could be loaded.

    Side effects:
        Writes ``大涨小跌预测结果_<start>-<end>名.html`` to the working directory.
    """
    # Validate the rank window. start_rank is clamped to top_n as well;
    # otherwise a request such as 60-70 with top_n=50 ends up as the
    # inverted range 60-50.
    if start_rank < 1:
        logger.warning(f"起始排名不能小于1,已自动调整为1 (原值: {start_rank})")
        start_rank = 1
    if top_n >= 1 and start_rank > top_n:
        logger.warning(f"起始排名不能大于总排名数,已自动调整为最大排名数 (原值: start_rank={start_rank}, top_n={top_n})")
        start_rank = top_n
    if end_rank < start_rank:
        logger.warning(f"结束排名不能小于起始排名,已自动调整为起始排名 (原值: end_rank={end_rank}, start_rank={start_rank})")
        end_rank = start_rank
    if end_rank > top_n:
        logger.warning(f"结束排名不能大于总排名数,已自动调整为最大排名数 (原值: end_rank={end_rank}, top_n={top_n})")
        end_rank = top_n

    # 1. Configuration.
    config = StockConfig()
    logger.info(f"===== 股票预测筛选程序 (排名范围: {start_rank}-{end_rank}) =====")
    logger.info(f"屏蔽规则: 将过滤掉前缀为 {config.BLOCKED_PREFIXES} 的股票")

    # 2. Prediction model.
    if not os.path.exists(model_path):
        logger.error(f"模型文件 {model_path} 不存在")
        return []
    try:
        # NOTE(security): joblib.load unpickles the file, which can run
        # arbitrary code. Only load model files from a trusted source.
        model, selected_features = joblib.load(model_path)
        logger.info(f"成功加载预测模型,使用特征: {selected_features}")
    except Exception as e:
        logger.error(f"加载模型失败: {str(e)}", exc_info=True)
        return []

    # 3. Clustering model (optional).
    cluster_model = StockCluster(config)
    cluster_model_loaded = cluster_model.load()
    if not cluster_model_loaded:
        logger.warning("无法加载聚类模型,使用默认聚类")

    # 4. Recent price history for every stock.
    logger.info("加载股票数据...")
    stock_data = load_prediction_data(config.SH_PATH, config.SZ_PATH, lookback_days=lookback_days)
    if not stock_data:
        logger.error("没有加载到任何股票数据")
        return []

    # 5. Feature engineering helper.
    feature_engineer = FeatureEngineer(config)

    # 6. Score each stock on its most recent row.
    predictions = []
    blocked_stocks = []  # codes skipped because of a blocked prefix
    blocked_prefixes = tuple(config.BLOCKED_PREFIXES)  # str.startswith accepts a tuple
    logger.info("处理股票数据并进行预测...")
    for stock_code, df in tqdm(stock_data.items(), desc="预测股票"):
        try:
            if stock_code.startswith(blocked_prefixes):
                blocked_stocks.append(stock_code)
                continue
            # Indicators must be computed on ascending dates.
            df = df.sort_values('date', ascending=True)
            df = feature_engineer.transform(df.copy())
            if cluster_model_loaded:
                df = cluster_model.transform(df, stock_code)
            # Only the latest row is scored.
            latest_data = df.iloc[-1:].copy()
            # Features the model expects but the frame lacks are set to 0.
            for feature in selected_features:
                if feature not in latest_data.columns:
                    latest_data[feature] = 0
            X_pred = latest_data[selected_features]
            # Probability of the positive class (column 1).
            proba = float(model.predict_proba(X_pred)[0, 1])
            predictions.append((stock_code, proba))
        except Exception as e:
            # A single failing stock must not abort the whole run.
            logger.error(f"处理股票 {stock_code} 失败: {str(e)}", exc_info=True)

    blocked_count = len(blocked_stocks)
    logger.info(f"已屏蔽 {blocked_count} 只股票: {blocked_stocks}")

    # 7. Rank by probability, keep the top_n, then cut the requested window.
    predictions.sort(key=lambda item: item[1], reverse=True)
    top_predictions = predictions[:top_n]
    selected_predictions = top_predictions[start_rank - 1:end_rank]  # ranks are 1-based

    # 8. Build the HTML report.
    prediction_date = datetime.now().strftime("%Y-%m-%d")
    html_content = generate_html_report(selected_predictions, prediction_date,
                                        start_rank=start_rank, end_rank=end_rank,
                                        blocked_count=blocked_count)

    # 9. Save the HTML report.
    html_file = f"大涨小跌预测结果_{start_rank}-{end_rank}名.html"
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html_content)
    logger.info(f"已生成HTML报告: {html_file}")
    return selected_predictions
if __name__ == "__main__":
    # Run one prediction pass for ranks 1-50. Other windows (for example
    # 1-10, 11-30 or 31-50) can be produced the same way by calling
    # predict_top_stocks with different start_rank/end_rank values and adding
    # the (result, "start-end") pair to ``results``.
    custom_range = predict_top_stocks(top_n=50, start_rank=1, end_rank=50)

    # Each entry pairs a prediction list with the label of its rank window.
    results = [
        (custom_range, "1-50"),
    ]

    # Write one CSV per non-empty result.
    for ranked, label in results:
        if not ranked:
            continue
        # The first rank comes from the label; the last one follows from the
        # number of rows actually returned.
        first_rank = int(label.split('-')[0])
        table = pd.DataFrame(ranked, columns=['股票代码', '上涨概率'])
        table['排名'] = range(first_rank, first_rank + len(ranked))
        csv_name = f'大涨小跌预测结果_{label}名.csv'
        table.to_csv(csv_name, index=False, encoding='utf-8-sig')
        logger.info(f"结果已保存到 {csv_name}")
# Paste residue (not code): "最新发布" - "latest release".