<think>根据错误信息,主要问题有两个:
1. 在计算技术指标(如RSI、MACD)时,出现了“input array type is not double”的错误。这是因为TA-Lib要求输入的数据类型是双精度浮点数(double),而我们之前为了节省内存将价格数据转换为了float32。
2. 在准备训练数据集时,由于前面的特征工程失败,导致没有成功处理任何股票数据,因此X_list为空,在合并时出现了“No objects to concatenate”的错误。
解决方案:
1. 在调用TA-Lib函数之前,将所需的列转换为双精度浮点数(float64)。计算完成后再转换回float32以节省内存。
2. 在特征工程类中,对于TA-Lib计算的部分,临时转换为float64,计算后再转换回float32。
3. 确保至少有一些股票数据被成功处理,避免合并空列表。
具体修改如下:
修改特征工程类中的transform方法:
- 对于需要TA-Lib计算的列(如close),在计算前转换为float64,计算完成后再将结果转换为float32,并将原数据转换回float32。
- 注意:我们只对计算技术指标时需要的列进行临时转换,避免内存占用过大。
此外,为了确保程序健壮性,我们还需要:
- 在特征工程中捕获异常,避免单只股票的失败影响整个程序。
- 确保即使有部分股票失败,只要还有成功处理的股票,程序就能继续运行。
下面我们修改代码:</think>根据错误日志,主要存在两个问题:
1. 特征计算时出现 "input array type is not double" 错误
2. 最终没有成功处理任何股票数据,导致数据合并失败
### 解决方案
#### 问题1: "input array type is not double" 错误
这个错误是由于TA-Lib库要求输入数据为双精度浮点数(double),而我们之前为了节省内存使用了float32类型。解决方案是在特征工程中临时转换为double进行计算。
修改特征工程类:
```python
class FeatureEngineer:
def __init__(self, config):
self.config = config
def transform(self, df):
"""添加技术指标特征"""
# 创建临时副本用于TA-Lib计算
df_temp = df.copy()
# 将价格列转换为float64以满足TA-Lib要求
for col in ['open', 'high', 'low', 'close']:
df_temp[col] = df_temp[col].astype(np.float64)
# 价格变化
df['price_change'] = df['close'].pct_change().fillna(0)
# 波动率
df['volatility'] = df['close'].rolling(5).std().fillna(0)
# 成交量变化
df['volume_change'] = df['volume'].pct_change().fillna(0)
# 移动平均线
df['MA5'] = df['close'].rolling(5).mean().fillna(0)
df['MA20'] = df['close'].rolling(20).mean().fillna(0)
# RSI - 使用临时副本
df['RSI14'] = ta.RSI(df_temp['close'].values, timeperiod=14)
df['RSI14'] = df['RSI14'].fillna(50) # 默认值50
# MACD - 使用临时副本
macd, macd_signal, macd_hist = ta.MACD(df_temp['close'].values,
fastperiod=12,
slowperiod=26,
signalperiod=9)
df['MACD'] = np.nan_to_num(macd)
df['MACD_signal'] = np.nan_to_num(macd_signal)
df['MACD_hist'] = np.nan_to_num(macd_hist)
# 布林带 - 使用临时副本
upper, middle, lower = ta.BBANDS(df_temp['close'].values,
timeperiod=20,
nbdevup=2,
nbdevdn=2)
df['BB_upper'] = np.nan_to_num(upper)
df['BB_middle'] = np.nan_to_num(middle)
df['BB_lower'] = np.nan_to_num(lower)
# 处理无穷大和NaN
df = df.replace([np.inf, -np.inf], np.nan)
df = df.fillna(0)
# 优化内存
return reduce_mem_usage(df)
```
#### 问题2: 没有成功处理任何股票数据
这通常是由于特征工程失败或数据格式问题导致的。我们需要添加更健壮的错误处理机制:
1. 在数据加载阶段添加格式验证
2. 在特征工程阶段添加异常捕获
3. 确保即使部分股票失败,也能处理其他股票
修改数据加载和特征工程部分:
```python
# 修改数据加载函数
def load_stock_data(sh_path, sz_path, start_date, end_date):
"""加载股票数据,并过滤日期范围"""
stock_data = {}
total_files = 0
# 统计总文件数
for exchange, path in [('SH', sh_path), ('SZ', sz_path)]:
if os.path.exists(path):
total_files += len([f for f in os.listdir(path) if f.endswith('.csv')])
pbar = tqdm(total=total_files, desc='加载股票数据')
for exchange, path in [('SH', sh_path), ('SZ', sz_path)]:
if not os.path.exists(path):
continue
for file in os.listdir(path):
if file.endswith('.csv'):
stock_code = f"{exchange}{file.split('.')[0]}"
file_path = os.path.join(path, file)
try:
# 读取数据并验证列名
df = pd.read_csv(file_path)
# 验证必要的列是否存在
required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_cols):
print(f"股票 {stock_code} 缺少必要列,跳过")
continue
# 转换日期并过滤
df['date'] = pd.to_datetime(df['date'])
df = df[(df['date'] >= start_date) & (df['date'] <= end_date)]
if len(df) > 100: # 至少100个交易日
# 转换数据类型
for col in ['open', 'high', 'low', 'close']:
# 确保是数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
df[col] = df[col].astype(np.float32)
df['volume'] = pd.to_numeric(df['volume'], errors='coerce')
df['volume'] = df['volume'].astype(np.uint32)
# 删除包含NaN的行
df = df.dropna(subset=required_cols)
if len(df) > 0:
stock_data[stock_code] = df
except Exception as e:
print(f"加载股票 {stock_code} 失败: {str(e)}")
pbar.update(1)
pbar.close()
print(f"成功加载 {len(stock_data)} 只股票数据")
return stock_data
# 修改特征工程中的异常处理
class FeatureEngineer:
def transform(self, df):
try:
# ... 特征计算代码 ...
except Exception as e:
print(f"特征工程失败: {str(e)}")
# 返回基本特征
df['price_change'] = df['close'].pct_change().fillna(0)
df['volatility'] = df['close'].rolling(5).std().fillna(0)
df['volume_change'] = df['volume'].pct_change().fillna(0)
df['MA5'] = df['close'].rolling(5).mean().fillna(0)
df['MA20'] = df['close'].rolling(20).mean().fillna(0)
# 填充缺失的技术指标
for col in ['RSI14', 'MACD', 'MACD_signal', 'MACD_hist',
'BB_upper', 'BB_middle', 'BB_lower']:
if col not in df.columns:
df[col] = 0
return df
```
### 完整修复后的主程序
```python
# 在原有程序中应用上述修复
# ========== 主程序 ==========
def main():
# 初始化配置
config = StockConfig()
# 加载训练数据
print(f"加载训练数据: {config.START_DATE} 至 {config.END_DATE}")
train_data = load_stock_data(config.SH_PATH, config.SZ_PATH,
config.START_DATE, config.END_DATE)
if not train_data:
print("错误: 没有加载到任何股票数据,请检查数据路径和格式")
return
# 特征工程
feature_engineer = FeatureEngineer(config)
# 聚类分析
cluster_model = StockCluster(config)
try:
cluster_model.fit(train_data)
except Exception as e:
print(f"聚类分析失败: {str(e)}")
# 创建默认聚类映射
cluster_model.cluster_map = {code: 0 for code in train_data.keys()}
print("使用默认聚类(所有股票归为同一类)")
# 准备训练数据
trainer = StockModelTrainer(config)
try:
X_train, y_train = trainer.prepare_dataset(train_data, cluster_model, feature_engineer)
except ValueError as e:
print(f"准备训练数据失败: {str(e)}")
# 尝试使用更简单的特征
print("尝试使用基本特征...")
config.PREDICT_FEATURES = ['open', 'high', 'low', 'close', 'volume', 'cluster']
X_train, y_train = trainer.prepare_dataset(train_data, cluster_model, feature_engineer)
if len(X_train) == 0:
print("错误: 没有可用的训练数据")
return
# 训练模型
model = trainer.train_model(X_train, y_train)
# 加载测试数据
print(f"\n加载测试数据: {config.TEST_START} 至 {config.TEST_END}")
test_data = load_stock_data(config.SH_PATH, config.SZ_PATH,
config.TEST_START, config.TEST_END)
if test_data:
# 准备测试数据
X_test, y_test = trainer.prepare_dataset(test_data, cluster_model, feature_engineer)
# 评估模型
y_pred = model.predict(X_test)
recall = recall_score(y_test, y_pred, pos_label=1)
print(f"测试集召回率: {recall:.4f}")
# 计算满足条件的样本比例
condition_ratio = sum(y_test == 1) / len(y_test)
print(f"满足条件的样本比例: {condition_ratio:.4f}")
else:
print("没有测试数据可用")
if __name__ == "__main__":
main()
```
### 关键修复点解释
1. **数据类型转换修复**:
```python
# 创建临时副本用于TA-Lib计算
df_temp = df.copy()
for col in ['open', 'high', 'low', 'close']:
df_temp[col] = df_temp[col].astype(np.float64)
```
- 创建临时DataFrame并转换为float64以满足TA-Lib要求
- 原始数据保持float32节省内存
2. **数据验证增强**:
```python
# 验证必要的列是否存在
required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_cols):
print(f"股票 {stock_code} 缺少必要列,跳过")
continue
# 确保是数值类型
df[col] = pd.to_numeric(df[col], errors='coerce')
```
- 检查必要列是否存在
- 使用`pd.to_numeric`强制转换数值类型
- 删除包含NaN的行
3. **健壮的错误处理**:
```python
try:
# 特征计算
except Exception as e:
print(f"特征工程失败: {str(e)}")
# 返回基本特征
df['price_change'] = ...
```
- 捕获特征工程中的异常
- 提供回退方案(基本特征)
- 填充缺失的技术指标
4. **备用方案**:
```python
# 聚类失败时使用默认聚类
cluster_model.cluster_map = {code: 0 for code in train_data.keys()}
# 特征工程失败时使用基本特征
config.PREDICT_FEATURES = ['open', 'high', 'low', 'close', 'volume', 'cluster']
```
- 聚类失败时将所有股票归为同一类
- 特征工程失败时仅使用基本价格特征
5. **数据可用性检查**:
```python
if not train_data:
print("错误: 没有加载到任何股票数据,请检查数据路径和格式")
return
if len(X_train) == 0:
print("错误: 没有可用的训练数据")
return
```
- 在关键步骤检查数据可用性
- 提供明确的错误信息
### 预防性措施
1. **数据预处理脚本**:
创建一个单独的数据验证脚本,确保所有CSV文件格式正确:
```python
import pandas as pd
import os
def validate_stock_data(directory):
for file in os.listdir(directory):
if file.endswith('.csv'):
path = os.path.join(directory, file)
try:
df = pd.read_csv(path)
# 验证列名
required = ['date', 'open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required):
print(f"文件 {file} 缺少必要列")
# 验证数据类型
for col in required:
pd.to_numeric(df[col], errors='raise')
except Exception as e:
print(f"文件 {file} 验证失败: {str(e)}")
# 验证上证和深证数据
validate_stock_data(r"D:\股票量化数据库\股票csv数据\上证")
validate_stock_data(r"D:\股票量化数据库\股票csv数据\深证")
```
2. **日志记录**:
添加详细日志记录,便于追踪问题:
```python
import logging
logging.basicConfig(
filename='stock_prediction.log',
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# 在代码中使用
try:
# 某些操作
except Exception as e:
logging.error(f"处理股票 {stock_code} 失败: {str(e)}", exc_info=True)
```
3. **逐步调试模式**:
```python
DEBUG_MODE = True # 设置为False关闭调试模式
if DEBUG_MODE:
# 只处理少量股票
test_codes = list(train_data.keys())[:10]
train_data = {code: train_data[code] for code in test_codes}
```
这些修复应该能解决您遇到的 "input array type is not double" 错误和数据合并失败问题,使程序能够成功运行。