
High-quality data loading and anomaly flagging (limit-up/limit-down marks still to be adjusted)

```python
# -*- coding: utf-8 -*-
import jqdata
import pandas as pd
import numpy as np
import logging
from jqdata import *
from datetime import datetime
from scipy.stats import iqr
from sklearn.ensemble import IsolationForest
import arch
from scipy.stats import gaussian_kde

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class StockDataLoader:
    def __init__(self):
        self.market_cap_data = {}
        self.special_events = self._create_special_events_db()
        self.security_info_cache = {}

    def _create_special_events_db(self):
        events = {
            'circuit_breaker': ['2016-01-04', '2016-01-07'],
            'major_events': [
                '2015-06-12', '2015-07-08', '2016-01-04', '2016-01-07',
                '2018-03-23', '2019-05-06', '2020-01-23', '2020-02-03',
                '2020-03-09', '2020-03-12', '2020-03-16', '2020-03-18',
                '2020-07-06', '2021-01-04', '2021-02-18', '2021-07-26',
                '2022-02-24', '2022-03-07', '2022-03-08', '2022-03-09',
                '2022-04-25', '2022-10-24', '2022-10-28', '2023-01-30'
            ],
            'black_swan': [
                '2015-06-12', '2018-03-23', '2020-01-23',
                '2020-03-09', '2021-07-26', '2022-02-24'
            ],
            'extreme_market': [
                '2015-06-19', '2015-06-26', '2015-07-03', '2015-07-09',
                '2015-08-24', '2015-08-25', '2016-01-04', '2016-01-07',
                '2018-10-11', '2019-05-06', '2020-02-03', '2020-07-06',
                '2020-07-16', '2021-02-18', '2021-07-27', '2022-03-15',
                '2022-04-25', '2022-10-24'
            ],
            'gem_reform': [
                '2020-08-24', '2020-08-27', '2020-09-09',
                '2021-02-18', '2021-04-19', '2021-08-24'
            ]
        }
        events_date = {}
        for key, date_list in events.items():
            events_date[key] = [pd.Timestamp(date) for date in date_list]
        return events_date

    def get_security_info_cached(self, stock_code):
        if stock_code not in self.security_info_cache:
            self.security_info_cache[stock_code] = get_security_info(stock_code)
        return self.security_info_cache[stock_code]

    def is_special_event(self, date, event_type):
        return date in self.special_events.get(event_type, [])

    def get_all_stocks(self):
        all_stocks = get_all_securities(types=['stock'], date='2012-12-31').index.tolist()
        stocks = []
        for stock in all_stocks:
            stock_info = get_security_info(stock)
            if stock_info:
                start_date = pd.Timestamp(stock_info.start_date)
                end_date = pd.Timestamp(stock_info.end_date) if stock_info.end_date else None
                if start_date < pd.Timestamp('2013-01-01'):
                    if end_date is None or end_date > pd.Timestamp('2023-12-31'):
                        stocks.append(stock)
                    else:
                        logger.debug(f"Excluding delisted stock: {stock}, delist date: {stock_info.end_date}")
        return stocks

    def filter_st_stocks(self, stocks, start_date, end_date):
        logger.info(f"Filtering ST stocks ({len(stocks)} candidates)")
        non_st_stocks = []
        batch_size = 100
        trade_days = get_trade_days(start_date=start_date, end_date=end_date)
        for i in range(0, len(stocks), batch_size):
            batch = stocks[i:i+batch_size]
            logger.info(f"Checking ST status, batch {i//batch_size+1}/{(len(stocks)-1)//batch_size+1}")
            for stock in batch:
                try:
                    st_status = get_extras('is_st', stock, start_date=start_date, end_date=end_date)
                    if st_status is None or st_status.dropna().empty:
                        non_st_stocks.append(stock)
                    else:
                        if st_status.iloc[:, 0].sum() == 0:
                            non_st_stocks.append(stock)
                        else:
                            st_dates = st_status[st_status.iloc[:, 0]].index.tolist()
                            logger.debug(f"Filtering out ST stock: {stock}, ST dates: {st_dates}")
                except Exception as e:
                    logger.error(f"ST status check failed for {stock}: {str(e)}")
                    non_st_stocks.append(stock)
        logger.info(f"Stocks remaining after filter: {len(non_st_stocks)}")
        return non_st_stocks

    def get_market_cap(self, stock, date='2012-12-31'):
        if stock not in self.market_cap_data:
            try:
                q = query(valuation).filter(valuation.code == stock)
                df = get_fundamentals(q, date=date)
                if not df.empty:
                    self.market_cap_data[stock] = df['market_cap'].iloc[0]
                else:
                    self.market_cap_data[stock] = np.nan
            except Exception as e:
                logger.warning(f"Failed to fetch market cap for {stock}: {str(e)}")
                self.market_cap_data[stock] = np.nan
        return self.market_cap_data[stock]

    def categorize_stocks(self, stocks):
        market_caps = []
        for stock in stocks:
            cap = self.get_market_cap(stock)
            if not np.isnan(cap):
                market_caps.append((stock, cap))
        sorted_stocks = sorted(market_caps, key=lambda x: x[1], reverse=True)
        total = len(sorted_stocks)
        large_cap = [s[0] for s in sorted_stocks[:total//3]]
        mid_cap = [s[0] for s in sorted_stocks[total//3:2*total//3]]
        small_cap = [s[0] for s in sorted_stocks[2*total//3:]]
        return large_cap, mid_cap, small_cap

    def sample_stocks(self, large_cap, mid_cap, small_cap, n=100):
        large_sample = np.random.choice(large_cap, min(n, len(large_cap)), replace=False) if large_cap else []
        mid_sample = np.random.choice(mid_cap, min(n, len(mid_cap)), replace=False) if mid_cap else []
        small_sample = np.random.choice(small_cap, min(n, len(small_cap)), replace=False) if small_cap else []
        return list(large_sample) + list(mid_sample) + list(small_sample)

    def calculate_price_limits(self, price_data):
        price_data = price_data.copy()
        unique_codes = price_data['code'].unique()
        security_types = {
            code: self.get_security_info_cached(code).type if self.get_security_info_cached(code) else 'normal'
            for code in unique_codes
        }
        price_data['security_type'] = price_data['code'].map(security_types)
        price_data['price_limit_threshold'] = 0.10
        gem_mask = (price_data['security_type'] == 'gem') & (price_data['date'] >= '2020-08-24')
        price_data.loc[gem_mask, 'price_limit_threshold'] = 0.20
        ks_mask = price_data['security_type'] == 'ks'
        price_data.loc[ks_mask, 'price_limit_threshold'] = 0.20
        bj_mask = price_data['security_type'] == 'bj'
        price_data.loc[bj_mask, 'price_limit_threshold'] = 0.30
        price_data['up_limit'] = np.round(price_data['pre_close'] * (1 + price_data['price_limit_threshold']), 2)
        price_data['down_limit'] = np.round(price_data['pre_close'] * (1 - price_data['price_limit_threshold']), 2)
        price_data['up_limit_hit'] = (
            (price_data['high'] >= price_data['up_limit'] - 0.015) &
            (price_data['low'] <= price_data['up_limit'] + 0.015)
        ).astype(int)
        price_data['down_limit_hit'] = (
            (price_data['low'] <= price_data['down_limit'] + 0.015) &
            (price_data['high'] >= price_data['down_limit'] - 0.015)
        ).astype(int)
        price_data['limit_one_way'] = (
            (price_data['up_limit_hit'] == 1) &
            (price_data['low'] == price_data['high']) &
            (price_data['open'] == price_data['close'])
        ).astype(int)
        price_data.drop(['security_type', 'up_limit', 'down_limit'], axis=1, inplace=True)
        return price_data

    def mark_special_events_vectorized(self, price_data):
        price_data = price_data.copy()
        price_data['special_events'] = ''
        for event_type in self.special_events.keys():
            event_mask = price_data['date'].isin(self.special_events[event_type])
            price_data.loc[event_mask, 'special_events'] = price_data.loc[event_mask, 'special_events'] + event_type + ','
        price_data['special_events'] = price_data['special_events'].str.rstrip(',')
        price_data['special_events'] = price_data['special_events'].replace('', np.nan)
        return price_data

    def mark_anomalies(self, price_data):
        """Anomaly marking layer: MAD, KDE, and non-parametric volatility detection."""
        if price_data.empty:
            return price_data
        # analyze normal trading days only
        valid_mask = (
            (price_data['suspended'] == 0) &
            (price_data['up_limit_hit'] == 0) &
            (price_data['down_limit_hit'] == 0) &
            price_data['special_events'].isna()
        )
        valid_data = price_data[valid_mask].copy()
        if valid_data.empty:
            return price_data
        valid_data['return'] = np.log(valid_data['close'] / valid_data['pre_close'])
        # initialize anomaly flag columns
        price_data['mad_anomaly'] = 0
        price_data['kde_anomaly'] = 0
        price_data['vol_anomaly'] = 0
        # MAD anomaly detection
        for stock, group in valid_data.groupby('code'):
            returns = group['return']
            if len(returns) < 10:
                continue
            median = returns.median()
            mad = np.median(np.abs(returns - median))
            threshold = 5 * 1.4826 * mad
            anomaly_mask = np.abs(returns - median) > threshold
            anomaly_indices = group[anomaly_mask].index
            price_data.loc[anomaly_indices, 'mad_anomaly'] = 1
        # KDE anomaly detection
        for stock, group in valid_data.groupby('code'):
            X = group[['return', 'volume']].values
            if len(X) < 20:
                continue
            X_norm = (X - X.mean(axis=0)) / X.std(axis=0)
            return_kde = gaussian_kde(X_norm[:, 0])
            volume_kde = gaussian_kde(X_norm[:, 1])
            densities = return_kde(X_norm[:, 0]) * volume_kde(X_norm[:, 1])
            threshold = np.percentile(densities, 1)
            anomaly_mask = densities < threshold
            anomaly_indices = group[anomaly_mask].index
            price_data.loc[anomaly_indices, 'kde_anomaly'] = 1
        # non-parametric volatility anomaly detection
        for stock, group in valid_data.groupby('code'):
            returns = group['return']
            if len(returns) < 20:
                continue
            realized_vol = returns.rolling(5).std()
            realized_vol_no_na = realized_vol.dropna()
            if len(realized_vol_no_na) == 0:
                continue
            med_vol = realized_vol_no_na.median()
            mad_vol = np.median(np.abs(realized_vol_no_na - med_vol))
            threshold = med_vol + 3 * 1.4826 * mad_vol
            anomaly_mask = realized_vol > threshold
            anomaly_indices = group[anomaly_mask].index
            price_data.loc[anomaly_indices, 'vol_anomaly'] = 1
        return price_data

    def load_price_data(self, stocks, start_date, end_date):
        trade_days = get_trade_days(start_date=start_date, end_date=end_date)
        logger.info(f"Number of trading days: {len(trade_days)}")
        data_frames = []
        batch_size = 100
        total = len(stocks)
        for i in range(0, total, batch_size):
            batch = stocks[i:i+batch_size]
            logger.info(f"Loading stock batch {i//batch_size+1}/{(total-1)//batch_size+1} ({len(batch)} stocks)")
            try:
                batch_data = get_price(
                    batch,
                    start_date=trade_days[0],
                    end_date=trade_days[-1],
                    fields=['open', 'close', 'high', 'low', 'volume', 'pre_close'],
                    frequency='daily',
                    panel=False,
                    skip_paused=False,
                    fq='pre',
                    fill_paused=True
                )
                if batch_data is None or batch_data.empty:
                    logger.warning("Batch load returned no data, skipping this batch")
                    continue
                if 'time' in batch_data.columns:
                    batch_data.rename(columns={'time': 'date'}, inplace=True)
                batch_data['date'] = pd.to_datetime(batch_data['date'])
                batch_data['suspended'] = (batch_data['volume'] == 0).astype(int)
                batch_data.sort_values(['code', 'date'], inplace=True)
                batch_data['prev_suspended'] = batch_data.groupby('code')['suspended'].shift(1)
                batch_data['resumption_first_day'] = ((batch_data['suspended'] == 0) & (batch_data['prev_suspended'] == 1)).astype(int)
                batch_data.drop('prev_suspended', axis=1, inplace=True)
                logger.info("Vectorized price-limit marking...")
                batch_data = self.calculate_price_limits(batch_data)
                logger.info("Vectorized special-event marking...")
                batch_data = self.mark_special_events_vectorized(batch_data)
                logger.info("Running the anomaly marking layer...")
                batch_data = self.mark_anomalies(batch_data)
                data_frames.append(batch_data)
            except Exception as e:
                logger.error(f"Batch load failed: {str(e)}, skipping this batch")
        if data_frames:
            combined = pd.concat(data_frames)
            return combined.set_index(['date', 'code']).sort_index()
        return pd.DataFrame()


def run_data_validation(data):
    """Refactored data validation function."""
    logger.info("\n" + "="*60)
    logger.info("Running data validation tests")
    logger.info("="*60)
    data_reset = data.reset_index()
    results = {
        'missing_values': {},
        'limit_issues': 0,
        'event_mark_issues': {},
        'anomaly_mark_stats': {'mad': 0, 'kde': 0, 'vol': 0},
        'special_cases': {'zero_volume': 0}
    }

    # 1. Check for missing values
    null_counts = data_reset.isnull().sum()
    results['missing_values'] = {col: count for col, count in null_counts.items() if count > 0}

    # 2. Check price-limit marks
    # recompute the limit-up threshold
    data_reset['calculated_threshold'] = np.round(
        data_reset['pre_close'] * (1 + data_reset['price_limit_threshold']), 2
    )
    false_negatives = data_reset[
        (data_reset['high'] >= data_reset['calculated_threshold'] - 0.015) &
        (data_reset['up_limit_hit'] == 0) &
        (data_reset['suspended'] == 0)
    ]
    false_positives = data_reset[
        (data_reset['high'] < data_reset['calculated_threshold'] - 0.015) &
        (data_reset['up_limit_hit'] == 1) &
        (data_reset['suspended'] == 0)
    ]
    results['limit_issues'] = len(false_negatives) + len(false_positives)

    # 3. Check special-event marks
    special_events_db = {
        'circuit_breaker': ['2016-01-04', '2016-01-07'],
        'major_events': [
            '2015-06-12', '2015-07-08', '2016-01-04', '2016-01-07',
            '2018-03-23', '2019-05-06', '2020-01-23', '2020-02-03',
            '2020-03-09', '2020-03-12', '2016-03-16', '2020-03-18',
            '2020-07-06', '2021-01-04', '2021-02-18', '2021-07-26',
            '2022-02-24', '2022-03-07', '2022-03-08', '2022-03-09',
            '2022-04-25', '2022-10-24', '2022-10-28', '2023-01-30'
        ],
        'black_swan': [
            '2015-06-12', '2018-03-23', '2020-01-23',
            '2020-03-09', '2021-07-26', '2022-02-24'
        ],
        'extreme_market': [
            '2015-06-19', '2015-06-26', '2015-07-03', '2015-07-09',
            '2015-08-24', '2015-08-25', '2016-01-04', '2016-01-07',
            '2018-10-11', '2019-05-06', '2020-02-03', '2020-07-06',
            '2020-07-16', '2021-02-18', '2021-07-27', '2022-03-15',
            '2022-04-25', '2022-10-24'
        ],
        'gem_reform': [
            '2020-08-24', '2020-08-27', '2020-09-09',
            '2021-02-18', '2021-04-19', '2021-08-24'
        ]
    }
    for event_type, date_list in special_events_db.items():
        dates = [pd.Timestamp(date) for date in date_list]
        marked = data_reset[data_reset['special_events'].str.contains(event_type, na=False)]
        expected = len(dates)
        actual = marked['date'].nunique()
        results['event_mark_issues'][event_type] = abs(actual - expected)

    # 4. Tally anomaly marks
    results['anomaly_mark_stats']['mad'] = data_reset['mad_anomaly'].sum()
    results['anomaly_mark_stats']['kde'] = data_reset['kde_anomaly'].sum()
    results['anomaly_mark_stats']['vol'] = data_reset['vol_anomaly'].sum()

    # 5. Check special cases
    # zero volume on non-suspended days
    results['special_cases']['zero_volume'] = data_reset[
        (data_reset['volume'] == 0) & (data_reset['suspended'] == 0)
    ].shape[0]

    # report validation results
    logger.info("Validation results:")
    if results['missing_values']:
        logger.warning(f"⚠️ Missing values: {results['missing_values']}")
    else:
        logger.info("✅ No missing values")
    if results['limit_issues'] > 0:
        logger.warning(f"⚠️ Price-limit mark issues: {results['limit_issues']} rows")
    else:
        logger.info("✅ Price-limit marks correct")
    for event_type, issues in results['event_mark_issues'].items():
        if issues > 0:
            logger.warning(f"⚠️ {event_type} event marks mismatch: {issues} differences")
        else:
            logger.info(f"✅ {event_type} event marks correct")
    logger.info(f"Anomaly mark stats - MAD: {results['anomaly_mark_stats']['mad']}, "
                f"KDE: {results['anomaly_mark_stats']['kde']}, "
                f"volatility: {results['anomaly_mark_stats']['vol']}")
    if results['special_cases']['zero_volume'] > 0:
        logger.warning(f"⚠️ Zero volume on non-suspended days: {results['special_cases']['zero_volume']} rows")
    else:
        logger.info("✅ No zero-volume issues on non-suspended days")
    logger.info("="*60)
    logger.info("Data validation tests complete")
    logger.info("="*60)
    return results


def main():
    logger.info("="*60)
    logger.info("Starting data load")
    logger.info("="*60)
    loader = StockDataLoader()
    logger.info("Fetching stocks listed before 2013 and not delisted during 2013-2023...")
    all_stocks = loader.get_all_stocks()
    logger.info(f"Found {len(all_stocks)} qualifying stocks")
    logger.info("Filtering stocks flagged ST during 2014-2023...")
    non_st_stocks = loader.filter_st_stocks(all_stocks, '2014-01-01', '2023-12-31')
    logger.info(f"Stocks remaining after filter: {len(non_st_stocks)}")
    logger.info("Grouping by market cap...")
    large_cap, mid_cap, small_cap = loader.categorize_stocks(non_st_stocks)
    logger.info(f"Grouping done: large caps ({len(large_cap)}), mid caps ({len(mid_cap)}), micro caps ({len(small_cap)})")
    logger.info("Randomly sampling stocks...")
    sampled_stocks = loader.sample_stocks(large_cap, mid_cap, small_cap, n=100)
    logger.info(f"Sampling done: {len(sampled_stocks)} stocks selected")
    logger.info("Loading 2014-2023 price data (pre-adjusted)...")
    price_data = loader.load_price_data(sampled_stocks, '2014-01-01', '2023-12-31')
    if price_data.empty:
        logger.error("Data load failed: no valid data")
        return
    logger.info(f"Data load complete: {len(price_data)} rows")
    run_data_validation(price_data)
    logger.info("="*60)
    logger.info("Data load and tests complete")
    logger.info("="*60)


if __name__ == "__main__":
    main()
```

What problems does the current code have?
<think>OK, I need to read through the user's Python code carefully and identify its problems. The code is mainly for loading and processing stock data, including limit-up/limit-down marking and anomaly detection. I should analyze it from several angles: structure, logic, performance, and potential bugs.

First I will read the whole script to understand what each class and function does, then check each part for likely problem spots. My thought process:

1. **Code structure**:
   - In the `StockDataLoader` initializer, does `_create_special_events_db` correctly convert all event dates to Timestamps? It looks right, but each date needs confirming.
   - In `get_all_stocks`, after fetching all securities, are the filter conditions correct? For instance, are the listing date and delisting date handled properly when the latter may be None?
   - In `filter_st_stocks`, does the `get_extras` call for ST status handle possible error cases?

2. **Performance**:
   - `load_price_data` fetches data in batches, but every batch runs `calculate_price_limits`, `mark_special_events_vectorized`, and `mark_anomalies`; could these run once on the merged data to avoid repeated work?
   - `mark_anomalies` loops over each stock's group, which may be inefficient for many stocks; can it be vectorized?

3. **Potential logic errors**:
   - Is the limit-hit logic in `calculate_price_limits` correct? For example, is the touch condition reasonable, and are there floating-point precision issues?
   - Does `mark_special_events_vectorized` correctly merge multiple event types into one field?
   - Are the detectors in `mark_anomalies` applied correctly? For example, is the KDE input properly normalized?

4. **Data validation**:
   - Does `run_data_validation` cover all the key checks? For example, is the limit-hit validation correct?
   - Does the special-event check use the same event list as the initializer? Note that `special_events_db` is re-defined inside `run_data_validation`, which risks inconsistency.

5. **Exception handling and logging**:
   - Is there enough exception handling, e.g. for network errors or abnormal API responses when fetching market caps?
   - Is the logging detailed enough for debugging and tracing?

6. **Other issues**:
   - Is the delisting exclusion in `get_all_stocks` correct? If a stock's end_date is empty, should it be treated as still listed?
   - `sample_stocks` uses `np.random.choice`; should a random seed be set so results are reproducible?
   - Do the flags from `mark_anomalies` cover all anomaly cases, e.g. anomalies around suspensions?

Now I need to work through each candidate problem and summarize the main issues.</think>

The current code has the following problems:

1. **Flawed limit-hit logic** - In `calculate_price_limits`, the touch condition has a precision problem:
   ```python
   (price_data['high'] >= price_data['up_limit'] - 0.015) &
   (price_data['low'] <= price_data['up_limit'] + 0.015)
   ```
   This loose band can misclassify; the comparison should be exact, with floating-point precision handled explicitly (see Sketch 1 after this list).

2. **Inconsistent special-event definitions** - The `special_events_db` re-defined inside `run_data_validation` diverges from the table in the class initializer and is a maintenance risk; for example, its list contains `2016-03-16`, a date that does not exist in the original data (see Sketch 2 below).

3. **Inefficient anomaly detection** - `mark_anomalies` loops over per-stock groups, creating a performance bottleneck. Vectorized operations are preferable:
   ```python
   # could be rewritten as a vectorized groupby + transform pass
   ```
   (see Sketch 3 below).

4. **Validation logic gap** - The limit-hit validation in `run_data_validation` filters on `suspended == 0`, but `calculate_price_limits` does not exclude suspended days when computing the marks, so the validation results may be inaccurate (Sketch 1 also addresses this).

5. **No seed control for random sampling** - `sample_stocks` draws randomly without setting a seed, so results are not reproducible:
   ```python
   # suggestion: add a random seed parameter
   np.random.seed(42)  # example seed value
   ```
   (see Sketch 4 below).

6. **Incomplete suspension handling** - In `load_price_data`, the `suspended` flag is derived solely from zero volume, ignoring other suspension cases (e.g. intraday trading halts).

7. **Anomaly flags never combined** - `mark_anomalies` produces three independent flags (mad_anomaly/kde_anomaly/vol_anomaly) with no combined anomaly score field (see Sketch 5 below).

8. **Memory management** - `load_price_data` keeps every batch in the `data_frames` list for the whole run, which can drive memory usage very high (see Sketch 6 below).

9. **Time-series alignment** - Different batches fetched via `get_price` may end up with inconsistent time series, and no unified date alignment is performed (see Sketch 7 below).

10. **Adjustment factors unverified** - `load_price_data` requests pre-adjusted prices (`fq='pre'`) but never validates the adjustment factors, which can distort historical price data.
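**Sketch 1 - exact limit detection (problems 1 and 4).** A minimal sketch of a stricter comparison, assuming the column names used in the original script (`pre_close`, `high`, `low`, `price_limit_threshold`, `suspended`); the `tol` of 0.001 is an illustrative assumption, not an exchange rule:

```python
import numpy as np
import pandas as pd

def mark_limit_hits(df: pd.DataFrame, tol: float = 1e-3) -> pd.DataFrame:
    """Mark limit hits by exact comparison against the rounded limit price."""
    df = df.copy()
    # Exchanges derive the limit price by rounding pre_close * (1 +/- threshold) to 2 decimals.
    up_limit = np.round(df['pre_close'] * (1 + df['price_limit_threshold']), 2)
    down_limit = np.round(df['pre_close'] * (1 - df['price_limit_threshold']), 2)
    # A tight np.isclose tolerance replaces the loose +/-0.015 band of the original.
    df['up_limit_hit'] = np.isclose(df['high'], up_limit, atol=tol).astype(int)
    df['down_limit_hit'] = np.isclose(df['low'], down_limit, atol=tol).astype(int)
    # Problem 4: suspended days should never carry limit-hit marks.
    df.loc[df['suspended'] == 1, ['up_limit_hit', 'down_limit_hit']] = 0
    return df
```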
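**Sketch 2 - one source of truth for special events (problem 2).** The simplest fix is a module-level table that both the loader and the validator read, so the two date lists cannot drift apart; the constant name `SPECIAL_EVENTS_RAW` is assumed for illustration, and only the first event type is spelled out:

```python
import pandas as pd

# Defined exactly once at module level.
SPECIAL_EVENTS_RAW = {
    'circuit_breaker': ['2016-01-04', '2016-01-07'],
    # the remaining event types from the class would be listed here, once
}

def special_events_as_timestamps():
    """Both _create_special_events_db and run_data_validation derive from the same table."""
    return {k: [pd.Timestamp(d) for d in v] for k, v in SPECIAL_EVENTS_RAW.items()}
```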
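**Sketch 3 - vectorized MAD detection (problem 3).** The MAD detector rewritten with `groupby`/`transform`, keeping the original 5 × 1.4826 × MAD threshold. It assumes `valid_data` already carries the `return` column that `mark_anomalies` computes; the minimum-sample guard (groups under 10 rows) is omitted for brevity:

```python
import numpy as np
import pandas as pd

def mad_anomaly_vectorized(valid_data: pd.DataFrame, k: float = 5.0) -> pd.Series:
    """Per-stock MAD anomaly flags without an explicit Python loop."""
    g = valid_data.groupby('code')['return']
    median = g.transform('median')                                  # per-stock median, broadcast to rows
    mad = g.transform(lambda s: np.median(np.abs(s - s.median())))  # per-stock MAD, broadcast to rows
    return (np.abs(valid_data['return'] - median) > k * 1.4826 * mad).astype(int)
```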
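**Sketch 4 - reproducible sampling (problem 5).** A seeded variant of `sample_stocks` using NumPy's `default_rng`; the seed value 42 is only an example:

```python
import numpy as np

def sample_stocks(large_cap, mid_cap, small_cap, n=100, seed=42):
    """Stratified random sample that yields the same draws on every run."""
    rng = np.random.default_rng(seed)  # seeded generator instead of the global RNG

    def draw(pool):
        return list(rng.choice(pool, size=min(n, len(pool)), replace=False)) if pool else []

    return draw(large_cap) + draw(mid_cap) + draw(small_cap)
```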
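**Sketch 5 - combined anomaly score (problem 7).** One way to fold the three flags into a single score plus a union flag; the weights are purely hypothetical and would need calibration against labeled anomalies:

```python
import pandas as pd

def combine_anomaly_flags(price_data: pd.DataFrame) -> pd.DataFrame:
    """Add a weighted anomaly score and a union flag on top of the three detectors."""
    df = price_data.copy()
    weights = {'mad_anomaly': 0.4, 'kde_anomaly': 0.3, 'vol_anomaly': 0.3}  # hypothetical weights
    df['anomaly_score'] = sum(df[col] * w for col, w in weights.items())
    df['any_anomaly'] = df[list(weights)].max(axis=1)  # 1 if any detector fired
    return df
```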
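**Sketch 6 - bounded memory per batch (problem 8).** A sketch that flushes each processed batch to disk instead of accumulating DataFrames in a list; Parquet and the file-name pattern are assumptions, any columnar format would do:

```python
import glob
import pandas as pd

def save_batch(batch_data: pd.DataFrame, batch_no: int, prefix: str = 'price_batch') -> None:
    # Write the batch out immediately; peak memory stays at roughly one batch.
    batch_data.to_parquet(f"{prefix}_{batch_no:03d}.parquet")

def load_all_batches(prefix: str = 'price_batch') -> pd.DataFrame:
    files = sorted(glob.glob(f"{prefix}_*.parquet"))
    return pd.concat((pd.read_parquet(f) for f in files), ignore_index=True)
```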
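**Sketch 7 - calendar alignment (problem 9).** Reindexing every stock onto the shared trading calendar before concatenation, assuming the long-format frame with `date`/`code` columns used in the script and that `trade_days` is the list returned by `get_trade_days`; missing sessions become NaN rows that downstream code can handle explicitly:

```python
import pandas as pd

def align_to_calendar(df: pd.DataFrame, trade_days) -> pd.DataFrame:
    """Give every stock one row per trading day of the shared calendar."""
    calendar = pd.DatetimeIndex(pd.to_datetime(trade_days))
    aligned = []
    for code, g in df.groupby('code'):
        g = g.set_index('date').reindex(calendar)  # insert missing sessions as NaN rows
        g['code'] = code                           # restore the code on inserted rows
        aligned.append(g.rename_axis('date').reset_index())
    return pd.concat(aligned, ignore_index=True)
```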