# 帮我改一下下面这段代码：把“次日(x+1日)收盘价比次日(x+1日)开盘价高出5%以上”的条件，
# 改成“次日(x+1日)开盘价比昨日(x-1日)收盘价高出10%，且今日(x日)收盘价不高于今日(x日)开盘价的105%”；
# 删除“次日(x+1日)最低价不低于次日(x+1日)开盘价的98%”这一限制条件。
# 然后，输出的预测模型命名为“尾盘选股预测模型”，输出的聚类模型命名为“微盘选股聚类模型”。
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 20 16:00:01 2025
@author: srx20
"""
import os
import gc
import numpy as np
import pandas as pd
import joblib
import talib as ta
from tqdm import tqdm
import random
from sklearn.cluster import MiniBatchKMeans
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RandomizedSearchCV, GroupKFold
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.metrics import make_scorer, recall_score, classification_report
import lightgbm as lgb
import logging
import psutil
import warnings
from scipy import sparse
# Suppress every Python warning for the whole process.
warnings.filterwarnings('ignore')

# Root logging: INFO and above go both to stock_prediction_fixed.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('stock_prediction_fixed.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
# ========== 配置类 ==========
class StockConfig:
    """Central settings: data locations, date ranges, features, search space and target thresholds."""

    def __init__(self):
        # Data directories (one CSV per stock)
        self.SH_PATH = r"D:\股票量化数据库\股票csv数据\上证"
        self.SZ_PATH = r"D:\股票量化数据库\股票csv数据\深证"
        # Training and test date ranges
        self.START_DATE = "2011-1-1"
        self.END_DATE = "2024-1-1"
        self.TEST_START = "2024-1-1"
        self.TEST_END = "2025-7-18"
        # Clustering settings
        self.CLUSTER_NUM = 8
        self.CLUSTER_FEATURES = [
            'price_change', 'volatility', 'volume_change',
            'MA5', 'MA20', 'RSI14', 'MACD_hist',
        ]
        # Prediction features (initial list; training may replace it with a selected subset)
        self.PREDICT_FEATURES = [
            'open', 'high', 'low', 'close', 'volume',
            'price_change', 'volatility', 'volume_change',
            'MA5', 'MA20', 'RSI14', 'MACD_hist',
            'cluster', 'MOM10', 'ATR14', 'VWAP', 'RSI_diff',
            'price_vol_ratio', 'MACD_RSI', 'advance_decline',
            'day_of_week', 'month',
        ]
        # LightGBM search space (kept small to limit memory use)
        self.PARAM_GRID = {
            'boosting_type': ['gbdt'],
            'num_leaves': [31, 63],
            'max_depth': [-1, 7],
            'learning_rate': [0.01, 0.05],
            'n_estimators': [300, 500],
            'min_child_samples': [50],
            'min_split_gain': [0.0, 0.1],
            'reg_alpha': [0, 0.1],
            'reg_lambda': [0, 0.1],
            'feature_fraction': [0.7, 0.9],
            'bagging_fraction': [0.7, 0.9],
            'bagging_freq': [1],
        }
        # Target thresholds (requested definition). Day x is positive when both hold:
        #   open[x+1] / close[x-1] - 1 > NEXT_OPEN_GAP    (next-day open more than 10% above the previous day's close)
        #   close[x] / open[x] - 1 <= MAX_INTRADAY_GAIN   (today's close not more than 5% above today's open)
        # The "next-day low >= 98% of next-day open" condition is not part of this definition.
        self.NEXT_OPEN_GAP = 0.10
        self.MAX_INTRADAY_GAIN = 0.05
        # Thresholds of the previous target definition (next-day close vs. next-day open gain,
        # next-day low / open ratio); retained for backward compatibility with code that reads them.
        self.MIN_GAIN = 0.05
        self.MIN_LOW_RATIO = 0.98
        # Debug mode
        self.DEBUG_MODE = False
        self.MAX_STOCKS = 50 if self.DEBUG_MODE else None
        self.SAMPLE_FRACTION = 0.3 if not self.DEBUG_MODE else 1.0  # fraction of stock files sampled
# ========== 内存管理工具 (修复版) ==========
def reduce_mem_usage(df, allow_float16=False):
    """Downcast the numeric columns of *df* to the smallest dtype that still holds their values.

    Integer columns stay within their own signedness (unsigned -> uint8..uint64,
    signed -> int8..int64). Float columns are narrowed to float32 at most unless
    ``allow_float16`` is set: float16 keeps only ~3 significant decimal digits, so
    prices such as 10.05 become 10.047 and ratios computed from them drift.
    A column is never widened, and empty, all-NaN or non-NumPy (extension) dtype
    columns are left untouched.

    Args:
        df: DataFrame whose numeric columns are replaced in place.
        allow_float16: opt back in to float16 for float columns whose range fits.

    Returns:
        The same DataFrame object.
    """
    start_mem = df.memory_usage().sum() / 1024**2
    for col in df.select_dtypes(include=['integer', 'floating']).columns:
        col_type = df[col].dtype
        if not isinstance(col_type, np.dtype):
            continue  # pandas extension dtypes (e.g. nullable Int64) are not handled here
        c_min = df[col].min()
        c_max = df[col].max()
        if pd.isna(c_min) or pd.isna(c_max):
            continue  # empty or all-NaN column: no range to size against

        if np.issubdtype(col_type, np.integer):
            limits = np.iinfo
            if np.issubdtype(col_type, np.unsignedinteger):
                # Previously 'uint32'[:3] != 'int' sent unsigned columns down the float path.
                candidates = (np.uint8, np.uint16, np.uint32, np.uint64)
            else:
                candidates = (np.int8, np.int16, np.int32, np.int64)
        else:
            limits = np.finfo
            candidates = (np.float16, np.float32) if allow_float16 else (np.float32,)

        for cand in candidates:
            info = limits(cand)
            if info.min <= c_min and c_max <= info.max:
                if np.dtype(cand).itemsize < col_type.itemsize:
                    df[col] = df[col].astype(cand)
                break
    end_mem = df.memory_usage().sum() / 1024**2
    saved_pct = 100 * (start_mem - end_mem) / start_mem if start_mem else 0.0
    logger.info(f'内存优化: 从 {start_mem:.2f} MB 减少到 {end_mem:.2f} MB ({saved_pct:.1f}%)')
    return df
def print_memory_usage():
    """Log the resident set size (RSS) of the current process, in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    logger.info(f"当前内存使用: {rss_bytes / (1024 ** 2):.2f} MB")
# ========== 数据加载 (修复版) ==========
def load_stock_data(sh_path, sz_path, start_date, end_date, sample_fraction=1.0, debug_mode=False, max_stocks=None):
    """Load per-stock daily CSV files from the SH and SZ directories, filtered to a date range.

    Args:
        sh_path, sz_path: directories holding one CSV per stock; a missing directory is skipped.
        start_date, end_date: inclusive bounds applied to the ``date`` column.
        sample_fraction: if < 1.0, only a random subset of the files (at least one) is read.
        debug_mode: stop after 10 stocks have been loaded successfully.
        max_stocks: optional cap on the number of successfully loaded stocks.

    Returns:
        dict mapping ``"<SH|SZ>_<file name before the first dot>"`` to a DataFrame sorted by
        date, with float32 open/high/low/close and an integer volume column (uint32 when all
        values fit, otherwise int64). Files missing a required column, with fewer than 50
        rows in range, or that fail to parse are skipped and logged.
    """
    stock_data = {}
    required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']

    all_files = []
    for exchange, path in [('SH', sh_path), ('SZ', sz_path)]:
        if os.path.exists(path):
            all_files.extend((exchange, path, f) for f in os.listdir(path) if f.endswith('.csv'))
    if not all_files:
        logger.warning("没有找到任何CSV文件")
        return stock_data

    if sample_fraction < 1.0:
        sample_size = max(1, int(len(all_files) * sample_fraction))
        all_files = random.sample(all_files, sample_size)
        logger.info(f"抽样 {len(all_files)} 只股票文件 (比例: {sample_fraction})")

    loaded_count = 0
    # Context manager closes the progress bar even if an unexpected error escapes the loop.
    with tqdm(total=len(all_files), desc='加载股票数据') as pbar:
        for exchange, path, file in all_files:
            if max_stocks is not None and loaded_count >= max_stocks:
                break
            stock_code = f"{exchange}_{file.split('.')[0]}"
            file_path = os.path.join(path, file)
            try:
                df = pd.read_csv(file_path)
                if not all(col in df.columns for col in required_cols):
                    logger.warning(f"股票 {stock_code} 缺少必要列,跳过")
                    pbar.update(1)
                    continue
                # Unparseable dates become NaT, which the range filter below excludes.
                df['date'] = pd.to_datetime(df['date'], errors='coerce')
                df = df[(df['date'] >= start_date) & (df['date'] <= end_date)].copy()
                if len(df) < 50:
                    logger.info(f"股票 {stock_code} 数据不足({len(df)}条),跳过")
                    pbar.update(1)
                    continue
                for col in ['open', 'high', 'low', 'close']:
                    df[col] = pd.to_numeric(df[col], errors='coerce').astype(np.float32)
                df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
                # Drop bad rows BEFORE the integer cast: casting NaN to uint32 raised and
                # discarded the whole stock.
                df = df.dropna(subset=required_cols)
                volume = df['volume']
                fits_uint32 = (
                    len(volume) > 0
                    and volume.min() >= 0
                    and volume.max() <= np.iinfo(np.uint32).max
                )
                # Fall back to int64 instead of letting out-of-range volumes wrap around.
                df['volume'] = volume.astype(np.uint32 if fits_uint32 else np.int64)
                # Downstream shift()-based features and targets assume chronological order.
                df = df.sort_values('date').reset_index(drop=True)
                if len(df) > 0:
                    stock_data[stock_code] = df
                    loaded_count += 1
                    logger.debug(f"成功加载股票 {stock_code},数据条数: {len(df)}")
                else:
                    logger.warning(f"股票 {stock_code} 过滤后无数据")
            except Exception as e:
                logger.error(f"加载股票 {stock_code} 失败: {str(e)}", exc_info=True)
            pbar.update(1)
            if debug_mode and loaded_count >= 10:
                logger.info("调试模式: 已加载10只股票,提前结束")
                break
    logger.info(f"成功加载 {len(stock_data)} 只股票数据")
    return stock_data
# ========== 特征工程 (修复版) ==========
class FeatureEngineer:
    """Derives technical-indicator and calendar features for one stock's daily bars."""

    def __init__(self, config):
        # Configuration object; only PREDICT_FEATURES is consulted (in transform's fallback path).
        self.config = config

    def safe_fillna(self, series, default=0):
        """Return *series* with NaN replaced by *default*.

        pandas Series go through ``fillna``; NumPy arrays through ``np.nan_to_num``
        (which also maps +/-inf to large finite values). Any other object is returned as is.
        """
        if isinstance(series, pd.Series):
            return series.fillna(default)
        if isinstance(series, np.ndarray):
            return np.nan_to_num(series, nan=default)
        return series

    @staticmethod
    def _add_basic_features(df):
        """Add the five rolling / percent-change columns shared by the main and fallback paths."""
        close = df['close']
        df['price_change'] = close.pct_change().fillna(0)
        df['volatility'] = close.rolling(5).std().fillna(0)
        df['volume_change'] = df['volume'].pct_change().fillna(0)
        df['MA5'] = close.rolling(5).mean().fillna(0)
        df['MA20'] = close.rolling(20).mean().fillna(0)

    def transform(self, df):
        """Add technical features to *df* (in place) and return the downcast frame.

        Added columns: price_change, volatility, volume_change, MA5, MA20, RSI14,
        MACD_hist, MOM10, ATR14, VWAP, RSI_diff, price_vol_ratio, MACD_RSI,
        advance_decline, day_of_week, month. +/-inf become NaN, every NaN is then
        filled with 0, and reduce_mem_usage is applied.

        If any step raises, the error is logged and the frame is returned with the
        five basic features recomputed and 0 in every PREDICT_FEATURES column that is
        still missing; this fallback frame is not downcast.
        """
        try:
            # TA-Lib only accepts float64 arrays, so convert the price columns up front.
            as_f64 = {
                name: df[name].to_numpy(dtype=np.float64, copy=True)
                for name in ('open', 'high', 'low', 'close')
            }
            close_f64 = as_f64['close']

            self._add_basic_features(df)

            df['RSI14'] = self.safe_fillna(ta.RSI(close_f64, timeperiod=14), 50)
            macd_hist = ta.MACD(close_f64, fastperiod=12, slowperiod=26, signalperiod=9)[2]
            df['MACD_hist'] = self.safe_fillna(macd_hist, 0)
            df['MOM10'] = self.safe_fillna(ta.MOM(close_f64, timeperiod=10), 0)
            atr = ta.ATR(as_f64['high'], as_f64['low'], close_f64, timeperiod=14)
            df['ATR14'] = self.safe_fillna(atr, 0)

            # Running volume-weighted average of the typical price (H + L + C) / 3.
            weighted = df['volume'] * (df['high'] + df['low'] + df['close']) / 3
            df['VWAP'] = self.safe_fillna(weighted.cumsum() / df['volume'].cumsum(), 0)

            rsi = df['RSI14']
            df['RSI_diff'] = rsi - rsi.rolling(5).mean().fillna(0)
            df['price_vol_ratio'] = df['price_change'] / (df['volatility'].replace(0, 1e-8) + 1e-8)
            df['MACD_RSI'] = df['MACD_hist'] * rsi
            # Count of up-days (close > open) over the last five bars.
            df['advance_decline'] = (df['close'] > df['open']).astype(int).rolling(5).sum().fillna(0)

            calendar = df['date'].dt
            df['day_of_week'] = calendar.dayofweek
            df['month'] = calendar.month

            df = df.replace([np.inf, -np.inf], np.nan)
            df = df.fillna(0)
            return reduce_mem_usage(df)
        except Exception as e:
            logger.error(f"特征工程失败: {str(e)}", exc_info=True)
            self._add_basic_features(df)
            for col in self.config.PREDICT_FEATURES:
                if col not in df.columns:
                    df[col] = 0
            return df
# ========== 聚类模型 (添加保存/加载功能) ==========
class StockCluster:
    """Clusters stocks on summary statistics of their features and maps stock code -> cluster id."""

    def __init__(self, config):
        self.config = config
        self.scaler = StandardScaler()
        self.kmeans = MiniBatchKMeans(
            n_clusters=config.CLUSTER_NUM,
            random_state=42,
            batch_size=1000,
        )
        self.cluster_map = {}  # stock code -> cluster id
        # Persisted model file, named as requested (previously "stock_cluster_model.pkl").
        self.model_file = "微盘选股聚类模型.pkl"

    def save(self):
        """Write kmeans, scaler, cluster_map and the configured cluster count to self.model_file (joblib)."""
        model_data = {
            'kmeans': self.kmeans,
            'scaler': self.scaler,
            'cluster_map': self.cluster_map,
            'config_cluster_num': self.config.CLUSTER_NUM,
        }
        joblib.dump(model_data, self.model_file)
        logger.info(f"聚类模型已保存到: {self.model_file}")

    def load(self):
        """Restore kmeans, scaler and cluster_map from self.model_file.

        Returns True on success and False (after a warning) when the file does not exist.
        NOTE: joblib.load unpickles its input; only load files this program wrote itself.
        """
        if not os.path.exists(self.model_file):
            logger.warning("聚类模型文件不存在,需要重新训练")
            return False
        model_data = joblib.load(self.model_file)
        self.kmeans = model_data['kmeans']
        self.scaler = model_data['scaler']
        self.cluster_map = model_data['cluster_map']
        logger.info(f"从 {self.model_file} 加载聚类模型")
        return True

    def _summary_features(self, df):
        """Return {"<feat>_mean", "<feat>_std"} for each CLUSTER_FEATURES entry; 0 when the column is absent."""
        features = {}
        for feat in self.config.CLUSTER_FEATURES:
            if feat in df.columns:
                features[f"{feat}_mean"] = df[feat].mean()
                features[f"{feat}_std"] = df[feat].std()
            else:
                features[f"{feat}_mean"] = 0
                features[f"{feat}_std"] = 0
        return features

    def fit(self, stock_data):
        """Fit scaler + k-means on per-stock summary statistics, fill cluster_map, then save.

        Stocks with fewer than 50 rows are skipped. Frames that lack the engineered
        CLUSTER_FEATURES columns (e.g. raw OHLCV frames as produced by load_stock_data)
        are run through FeatureEngineer first; otherwise every statistic would be 0 and
        all stocks would fall into one cluster. With no usable stock, every code is
        mapped to cluster 0 and nothing is saved.
        """
        logger.info("开始股票聚类分析...")
        codes = []
        rows = []
        engineer = None
        needed = set(self.config.CLUSTER_FEATURES)
        for stock_code, df in tqdm(stock_data.items(), desc="提取聚类特征"):
            if len(df) < 50:
                continue
            if not needed.issubset(df.columns):
                if engineer is None:
                    engineer = FeatureEngineer(self.config)
                df = engineer.transform(df.copy())
            codes.append(stock_code)
            rows.append(self._summary_features(df))

        if not rows:
            logger.warning("没有可用的聚类特征,使用默认聚类")
            self.cluster_map = {code: 0 for code in stock_data.keys()}
            return self

        feature_df = reduce_mem_usage(pd.DataFrame(rows))
        scaled_features = self.scaler.fit_transform(feature_df)
        self.kmeans.fit(scaled_features)
        clusters = self.kmeans.predict(scaled_features)
        feature_df['cluster'] = clusters
        # codes[i] is the stock that produced row i, so the mapping stays aligned even when
        # short histories were skipped (slicing list(stock_data) by position did not).
        for stock_code, cluster_id in zip(codes, clusters):
            self.cluster_map[stock_code] = int(cluster_id)

        logger.info("聚类分布统计:")
        logger.info(feature_df['cluster'].value_counts().to_string())
        logger.info(f"股票聚类完成,共分为 {self.config.CLUSTER_NUM} 个类别")
        self.save()
        return self

    def _predict_unseen(self, df):
        """Predict a cluster id for a stock missing from cluster_map; -1 when that is not possible."""
        if getattr(self.kmeans, 'cluster_centers_', None) is None:
            return -1  # k-means never fitted (e.g. the default single-cluster fallback)
        try:
            row = pd.DataFrame([self._summary_features(df)]).fillna(0)
            return int(self.kmeans.predict(self.scaler.transform(row))[0])
        except (ValueError, AttributeError):
            # Best effort (e.g. a loaded model built with other features): keep "unknown".
            return -1

    def transform(self, df, stock_code):
        """Set df['cluster'] for *stock_code* and return df.

        Known stocks get the id stored at fit/load time. Other stocks are assigned by
        predicting from their own summary statistics when k-means is fitted, else -1.
        """
        cluster_id = self.cluster_map.get(stock_code)
        if cluster_id is None:
            cluster_id = self._predict_unseen(df)
        df['cluster'] = cluster_id
        return df
# ========== 目标创建 ==========
class TargetCreator:
    """Builds the binary training label for each trading day x of one stock."""

    def __init__(self, config):
        self.config = config

    def create_targets(self, df):
        """Label day x as 1 when both requested conditions hold, else 0.

        1. open[x+1] / close[x-1] - 1 > NEXT_OPEN_GAP (default 0.10): the next day opens
           more than 10% above the previous day's close.
        2. close[x] / open[x] - 1 <= MAX_INTRADAY_GAIN (default 0.05): today's close is
           not more than 5% above today's open.

        There is no condition on the next day's low. Rows whose x-1 or x+1 prices are
        unavailable compare as NaN and get label 0; the last row (no day x+1) is dropped.

        Returns:
            df without its last row, with added columns ``next_open_vs_prev_close_gain``,
            ``today_close_to_open_gain`` and ``target``.
        """
        # getattr keeps this working with config objects that predate these thresholds.
        next_open_gap = getattr(self.config, 'NEXT_OPEN_GAP', 0.10)
        max_intraday_gain = getattr(self.config, 'MAX_INTRADAY_GAIN', 0.05)

        # Compute ratios in float64: price columns may have been downcast to low-precision floats.
        open_px = df['open'].astype(np.float64)
        close_px = df['close'].astype(np.float64)

        df['next_open_vs_prev_close_gain'] = open_px.shift(-1) / close_px.shift(1) - 1
        df['today_close_to_open_gain'] = close_px / open_px - 1

        mask = (df['next_open_vs_prev_close_gain'] > next_open_gap) & \
               (df['today_close_to_open_gain'] <= max_intraday_gain)
        df['target'] = mask.astype(np.int64)

        # The last row has no next day, so its label is undefined.
        df = df.iloc[:-1]

        target_counts = df['target'].value_counts()
        logger.info(f"目标分布: 0={target_counts.get(0, 0)}, 1={target_counts.get(1, 0)}")
        if self.config.DEBUG_MODE:
            sample_targets = df[['open', 'close', 'next_open_vs_prev_close_gain',
                                 'today_close_to_open_gain', 'target']].tail(5)
            logger.debug(f"目标创建示例:\n{sample_targets}")
        return df
# ========== 模型训练 (内存优化版) ==========
class StockModelTrainer:
    """Builds the training matrix, selects features, and tunes/trains a LightGBM classifier."""

    def __init__(self, config):
        self.config = config
        # Saved as f"{model_name}.pkl"; named as requested (previously "stock_prediction_model").
        self.model_name = "尾盘选股预测模型"
        self.feature_importance = None

    def prepare_dataset(self, stock_data, cluster_model, feature_engineer):
        """Turn {stock_code: OHLCV DataFrame} into (X, y, groups).

        Each stock goes through feature engineering, cluster assignment and target
        creation. Stocks that fail, lack a target or contain NaN features are skipped.

        Returns:
            (X, y, groups): X is a CSR float32 matrix whose columns follow
            config.PREDICT_FEATURES, y the 0/1 labels, groups the stock code of every row
            (used by GroupKFold). (None, None, None) if no stock produced rows.
        """
        logger.info("准备训练数据集...")
        X_list = []
        y_list = []
        stock_group_list = []  # group id per row for grouped cross-validation
        target_creator = TargetCreator(self.config)
        for stock_code, df in tqdm(stock_data.items(), desc="处理股票数据"):
            try:
                df = feature_engineer.transform(df.copy())
                df = cluster_model.transform(df, stock_code)
                df = target_creator.create_targets(df)
                features = self.config.PREDICT_FEATURES
                if 'target' not in df.columns:
                    logger.warning(f"股票 {stock_code} 缺少目标列,跳过")
                    continue
                X = df[features]
                y = df['target']
                if X.isnull().any().any():
                    logger.warning(f"股票 {stock_code} 特征包含NaN值,跳过")
                    continue
                X_list.append(sparse.csr_matrix(X.values.astype(np.float32)))
                y_list.append(y.values)
                stock_group_list.extend([stock_code] * len(X))
                if len(X_list) % 100 == 0:
                    gc.collect()
                    print_memory_usage()
            except Exception as e:
                logger.error(f"处理股票 {stock_code} 失败: {str(e)}", exc_info=True)
        if not X_list:
            logger.error("没有可用的训练数据")
            return None, None, None
        # Explicit CSR: feature_selection slices rows, which COO matrices do not support.
        X_full = sparse.vstack(X_list, format='csr')
        y_full = np.concatenate(y_list)
        groups = np.array(stock_group_list)
        logger.info(f"数据集准备完成,样本数: {X_full.shape[0]}")
        logger.info(f"目标分布: 0={np.count_nonzero(y_full == 0)}, 1={np.count_nonzero(y_full == 1)}")
        return X_full, y_full, groups

    @staticmethod
    def _subset_columns(X, all_features, selected_features):
        """Return the columns of *X* named in selected_features, in that order.

        X's columns are assumed to follow all_features. DataFrames are indexed by name;
        sparse matrices and ndarrays by position.
        """
        if isinstance(X, pd.DataFrame):
            return X[selected_features]
        positions = [all_features.index(f) for f in selected_features]
        return X[:, positions]

    def feature_selection(self, X, y):
        """Keep the 15 most important features according to a quick LightGBM fit.

        Side effect: config.PREDICT_FEATURES is replaced by the selected list, so later
        prepare_dataset calls build matrices with exactly the selected columns.

        Returns:
            (X_selected, selected_features)
        """
        logger.info("执行特征选择...")
        # Column order of X; must be captured before config.PREDICT_FEATURES is overwritten.
        all_features = list(self.config.PREDICT_FEATURES)
        base_model = lgb.LGBMClassifier(
            n_estimators=100,
            random_state=42,
            n_jobs=-1,
        )
        # Train in row batches, continuing boosting from the previous batch's model.
        batch_size = 100000
        for i in range(0, X.shape[0], batch_size):
            end_idx = min(i + batch_size, X.shape[0])
            X_batch = X[i:end_idx].toarray() if sparse.issparse(X) else X[i:end_idx]
            y_batch = y[i:end_idx]
            if i == 0:
                base_model.fit(X_batch, y_batch)
            else:
                base_model.fit(X_batch, y_batch, init_model=base_model)

        importance = pd.Series(base_model.feature_importances_, index=all_features)
        importance = importance.sort_values(ascending=False)
        logger.info("特征重要性:\n" + importance.to_string())
        k = min(15, len(all_features))
        selected_features = importance.head(k).index.tolist()
        logger.info(f"选择前 {k} 个特征: {selected_features}")

        # Index against the ORIGINAL list. Looking indices up after overwriting the config
        # always produced 0..k-1, i.e. the first k columns instead of the important ones.
        X_selected = self._subset_columns(X, all_features, selected_features)
        self.config.PREDICT_FEATURES = selected_features
        return X_selected, selected_features

    def train_model(self, X, y, groups):
        """Select features, tune LightGBM for positive-class recall, fit and save the final model.

        Returns the fitted classifier, or None when there is no training data. The model is
        written with its feature list to f"{model_name}.pkl".
        """
        if X is None or len(y) == 0:
            logger.error("训练数据为空,无法训练模型")
            return None
        logger.info("开始训练模型...")
        # 1. Class imbalance
        pos_count = int(np.count_nonzero(y == 1))
        neg_count = int(np.count_nonzero(y == 0))
        scale_pos_weight = neg_count / pos_count if pos_count > 0 else 1.0
        logger.info(f"类别不平衡处理: 正样本权重 = {scale_pos_weight:.2f}")
        # 2. Feature selection
        X_selected, selected_features = self.feature_selection(X, y)

        # 3. Score on recall of the positive class
        def positive_recall_score(y_true, y_pred):
            return recall_score(y_true, y_pred, pos_label=1)

        custom_scorer = make_scorer(positive_recall_score, greater_is_better=True)
        # 4. Grouped CV: all rows of one stock stay in the same fold
        group_kfold = GroupKFold(n_splits=2)
        cv = list(group_kfold.split(X_selected, y, groups=groups))
        # 5. Base estimator
        model = lgb.LGBMClassifier(
            objective='binary',
            random_state=42,
            n_jobs=-1,
            scale_pos_weight=scale_pos_weight,
            verbose=-1,
        )
        # 6. Randomized search (few iterations, single job to limit memory)
        search = RandomizedSearchCV(
            estimator=model,
            param_distributions=self.config.PARAM_GRID,
            n_iter=10,
            scoring=custom_scorer,
            cv=cv,
            verbose=2,
            n_jobs=1,
            pre_dispatch='2*n_jobs',
            random_state=42,
        )
        logger.info("开始参数搜索...")
        X_dense = X_selected.toarray() if sparse.issparse(X_selected) else X_selected
        search.fit(X_dense, y)
        # 7. Final model with the best parameters
        best_params = search.best_params_
        logger.info(f"最佳参数: {best_params}")
        logger.info(f"最佳召回率: {search.best_score_}")
        final_model = lgb.LGBMClassifier(
            **best_params,
            objective='binary',
            random_state=42,
            n_jobs=-1,
            scale_pos_weight=scale_pos_weight,
        )
        logger.info("训练最终模型...")
        # NOTE(review): eval_set is the training data itself, so early stopping here does
        # not guard against overfitting; a held-out split would be needed for that.
        final_model.fit(
            X_dense, y,
            eval_set=[(X_dense, y)],
            eval_metric='binary_logloss',
            callbacks=[
                lgb.early_stopping(stopping_rounds=50, verbose=False),
                lgb.log_evaluation(period=100),
            ],
        )
        self.feature_importance = pd.Series(
            final_model.feature_importances_,
            index=selected_features,
        ).sort_values(ascending=False)
        # 8. Persist (model, feature list)
        model_path = f"{self.model_name}.pkl"
        joblib.dump((final_model, selected_features), model_path)
        logger.info(f"模型已保存到: {model_path}")
        return final_model

    def evaluate_model(self, model, X_test, y_test):
        """Log recall, positive-label share, a classification report and feature importance."""
        # np.shape works for lists, ndarrays and scipy sparse matrices (len() fails on sparse).
        if model is None or X_test is None or np.shape(X_test)[0] == 0:
            logger.warning("无法评估模型,缺少数据或模型")
            return
        y_pred = model.predict(X_test)
        recall = recall_score(y_test, y_pred, pos_label=1)
        logger.info(f"测试集召回率: {recall:.4f}")
        condition_ratio = np.count_nonzero(y_test == 1) / len(y_test)
        logger.info(f"满足条件的样本比例: {condition_ratio:.4f}")
        report = classification_report(y_test, y_pred)
        logger.info("分类报告:\n" + report)
        if self.feature_importance is not None:
            logger.info("特征重要性:\n" + self.feature_importance.to_string())
# ========== 主程序 ==========
def main():
    """Run the whole pipeline: load training data, cluster stocks, train, then evaluate on the test period."""
    config = StockConfig()
    logger.info("===== 股票上涨预测程序 (修复版) =====")

    # Shared sampling options for both the training and the test load.
    load_options = dict(
        sample_fraction=config.SAMPLE_FRACTION,
        debug_mode=config.DEBUG_MODE,
        max_stocks=config.MAX_STOCKS,
    )

    logger.info(f"加载训练数据: {config.START_DATE} 至 {config.END_DATE}")
    train_data = load_stock_data(
        config.SH_PATH, config.SZ_PATH, config.START_DATE, config.END_DATE, **load_options
    )
    if not train_data:
        logger.error("错误: 没有加载到任何股票数据,请检查数据路径和格式")
        return

    feature_engineer = FeatureEngineer(config)

    # Reuse a saved clustering model when possible; otherwise fit one, falling back to a
    # single shared cluster (which is saved) if fitting fails.
    cluster_model = StockCluster(config)
    if not cluster_model.load():
        try:
            cluster_model.fit(train_data)
        except Exception as e:
            logger.error(f"聚类分析失败: {str(e)}", exc_info=True)
            cluster_model.cluster_map = dict.fromkeys(train_data, 0)
            logger.info("使用默认聚类(所有股票归为同一类)")
            cluster_model.save()

    trainer = StockModelTrainer(config)
    try:
        X_train, y_train, groups = trainer.prepare_dataset(train_data, cluster_model, feature_engineer)
    except Exception as e:
        logger.error(f"准备训练数据失败: {str(e)}", exc_info=True)
        return
    if X_train is None or len(y_train) == 0:
        logger.error("错误: 没有可用的训练数据")
        return

    model = trainer.train_model(X_train, y_train, groups)
    if model is None:
        logger.error("模型训练失败")
        return

    logger.info(f"\n加载测试数据: {config.TEST_START} 至 {config.TEST_END}")
    test_data = load_stock_data(
        config.SH_PATH, config.SZ_PATH, config.TEST_START, config.TEST_END, **load_options
    )
    if not test_data:
        logger.warning("没有测试数据可用")
    else:
        X_test, y_test, _ = trainer.prepare_dataset(test_data, cluster_model, feature_engineer)
        if X_test is None or len(y_test) == 0:
            logger.warning("测试数据准备失败,无法评估模型")
        else:
            if sparse.issparse(X_test):
                X_test = X_test.toarray()
            trainer.evaluate_model(model, X_test, y_test)
    logger.info("===== 程序执行完成 =====")


if __name__ == "__main__":
    main()
# 最新发布