# stock_factor_processing_dai.py
import dai
import pandas as pd
import numpy as np
from typing import List, Dict, Optional, Tuple
class StockFactorProcessorDAI:
"""
基于DAI的数据拉取与因子处理类(修正版)
说明:
- 使用 dai.query(sql, filters=...) 拉取分区表数据,推荐通过 filters 传入 date/instrument 分区条件以加速查询
- DataFrame 使用 MultiIndex (date, instrument)
"""
def __init__(self):
pass
def fetch_price_data(
self,
instruments: List[str],
start_date: str,
end_date: str,
        fields: Optional[List[str]] = None
) -> pd.DataFrame:
"""
使用 dai.query 获取日线级别行情 (示例使用 cn_stock_bar1d)
注意:
- 不要在 SQL 里重复对 date 做过滤。如果使用 filters 参数,请把 date 放到 filters 里(避免在 SQL 中再写 date 条件)。
- filters 会用于查询里的所有表,能显著加速分区表读取。
返回:
- pandas.DataFrame, index 为 ['date','instrument'] 的 MultiIndex
"""
        # Sanity checks
        if not instruments:
            raise ValueError("instruments list must not be empty")
        if fields is None:
            fields = ["open", "high", "low", "close", "volume", "amount"]
        # Format the instrument list as a SQL IN clause (small pre-filter);
        # the date partition condition still goes through `filters` for speed
        instruments_sql = ",".join([f"'{s}'" for s in instruments])
sql = f"""
SELECT date, instrument, {', '.join(fields)}
FROM cn_stock_bar1d
WHERE instrument IN ({instruments_sql})
"""
filters = {"date": [start_date, end_date]}
try:
res = dai.query(sql, filters=filters).df()
except Exception as e:
print("DAI 查询失败:", e)
return pd.DataFrame()
if res.empty:
return res
        # Ensure `date` is a datetime column
        if not pd.api.types.is_datetime64_any_dtype(res["date"]):
            res["date"] = pd.to_datetime(res["date"])
        # Set the MultiIndex and sort
        res = res.set_index(["date", "instrument"]).sort_index()
return res
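    # Usage sketch (hedged): assumes a configured DAI environment; the codes and
    # dates below are hypothetical placeholders.
    #   proc = StockFactorProcessorDAI()
    #   px = proc.fetch_price_data(["000001.SZ", "600000.SH"], "2023-01-01", "2023-12-31")
    #   # px.index -> MultiIndex (date, instrument); columns -> OHLCV + amount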
def build_full_index(
self,
raw_df: pd.DataFrame,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> pd.DataFrame:
"""
构建连续日期 x instruments 的 MultiIndex 并重新索引原始数据以填补缺失
"""
if raw_df.empty:
return raw_df
        # Derive the instrument universe and date range
instruments = raw_df.index.get_level_values("instrument").unique()
if start_date is None:
start_date = raw_df.index.get_level_values("date").min()
if end_date is None:
end_date = raw_df.index.get_level_values("date").max()
date_range = pd.date_range(start=start_date, end=end_date, freq="D")
full_index = pd.MultiIndex.from_product([date_range, instruments], names=["date", "instrument"])
full_df = raw_df.reindex(full_index)
return full_df
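    # Sketch (hedged): continuing the fetch example above,
    #   full = proc.build_full_index(px)
    # yields one row per (calendar day, instrument); non-trading days are NaN
    # rows that can be forward-filled or dropped depending on the use case.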
def compute_base_factors(self, price_df: pd.DataFrame) -> Dict[str, pd.Series]:
"""
计算基础因子(示例)
输入 price_df 为 MultiIndex(date, instrument) 的 DataFrame, 包含 close, volume, amount 等列
返回 dict: {factor_name: pd.Series (同 price_df.index)}
"""
if price_df.empty:
return {}
        # Per-instrument time-series computations via groupby(level='instrument')
        close = price_df["close"]
        vol = price_df["volume"]
        amt = price_df.get("amount", pd.Series(index=close.index, dtype=float))
        # 20-day momentum: close / close.shift(20) - 1
        momentum = close.groupby(level="instrument").pct_change(periods=20)
        # Simplified value factor: close / amount (amount is turnover; a real
        # value factor would normally use market cap or fundamentals instead)
        value = close / amt.replace({0: np.nan})
        # Volatility: sample standard deviation of the last 20 daily returns
        ret = close.groupby(level="instrument").pct_change()
        volatility = ret.groupby(level="instrument").rolling(20).std().droplevel(0)
        # Liquidity: today's volume / 20-day average volume
        avg_vol_20 = vol.groupby(level="instrument").rolling(20).mean().droplevel(0)
        liquidity = vol / avg_vol_20
factors = {
"momentum": momentum,
"value": value,
"volatility": volatility,
"liquidity": liquidity
}
return factors
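    # Worked sketch (hedged): with `px` from fetch_price_data above,
    #   factors = proc.compute_base_factors(px)
    #   factors["momentum"]  # pd.Series indexed by (date, instrument)
    # e.g. a close of 11.0 today against 10.0 twenty rows earlier gives
    # momentum = 11.0 / 10.0 - 1 = 0.10.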
def standardize_factors(
self,
factors: Dict[str, pd.Series],
mode: str = "cross" # 'cross' 按日期截面标准化; 'global' 全样本 z-score
) -> Dict[str, pd.Series]:
"""
因子标准化与缩尾
mode:
- 'cross' : 在每个 date 的截面上做 z-score 标准化(通常用于因子在截面上比较)
- 'global': 全样本 z-score
返回 standardized_factors dict
"""
std_factors = {}
for name, s in factors.items():
if s is None or s.empty:
std_factors[name] = s
continue
if mode == "cross":
# 对每个 date 做截面 z-score
def zscore_group(x):
m = x.mean()
sd = x.std()
if pd.isna(sd) or sd == 0:
return x * 0.0
return (x - m) / sd
# groupby level date
z = s.groupby(level="date").apply(lambda g: zscore_group(g.droplevel(0)))
# groupby.apply 会把 index 变成 (date, instrument) 两级, 恢复 index
z.index = s.index # 对齐回原索引
standardized = z
            else:
                # Global z-score over the full sample
                mean = s.mean()
                std = s.std()
                standardized = (s - mean) / std if std and std > 0 else s * 0.0
            # Winsorize (clip) at the per-date 1% / 99% quantiles
            lower = standardized.groupby(level="date").quantile(0.01).reindex(standardized.index, level="date")
            upper = standardized.groupby(level="date").quantile(0.99).reindex(standardized.index, level="date")
            # Clip after aligning the bounds to the full index
            standardized = standardized.clip(lower=lower, upper=upper)
std_factors[name] = standardized
return std_factors
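    # Sketch (hedged):
    #   std_factors = proc.standardize_factors(factors, mode="cross")
    # Each factor then has mean ~0 and std ~1 within every date, with extremes
    # clipped to that date's 1%/99% quantiles.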
def build_composite_factor(
self,
standardized_factors: Dict[str, pd.Series],
weights: Optional[Dict[str, float]] = None
) -> pd.Series:
"""
复合因子 =标准化因子加权和
如果 weights 未提供,采用等权
"""
if not standardized_factors:
return pd.Series(dtype=float)
names = list(standardized_factors.keys())
if weights is None:
w = {n: 1.0 / len(names) for n in names}
else:
w = weights
        # Take the first factor's index (all factors are expected to share it)
idx = standardized_factors[names[0]].index
composite = pd.Series(0.0, index=idx)
for n in names:
factor = standardized_factors[n].reindex(idx)
weight = w.get(n, 0.0)
composite = composite.add(factor.fillna(0.0) * weight, fill_value=0.0)
return composite
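    # Sketch (hedged): explicit weights; keys must match the factor dict, and
    # these particular numbers are illustrative only.
    #   composite = proc.build_composite_factor(
    #       std_factors,
    #       weights={"momentum": 0.4, "value": 0.2, "volatility": 0.2, "liquidity": 0.2},
    #   )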
def calculate_daily_rank_ic(
self,
factor: pd.Series,
forward_returns: pd.Series
) -> pd.Series:
"""
计算每日 Rank-IC(按日期截面):
- 对每个 date,计算 factor(rank) 与 forward_returns(rank) 的 Spearman 相关系数
返回:
- pd.Series, index=date, 每日 rank_ic
"""
        # Join the two series, then group by date
        if factor.empty or forward_returns.empty:
            return pd.Series(dtype=float)
        df = pd.concat([factor.rename("factor"), forward_returns.rename("fut_ret")], axis=1)
        df = df.dropna(how="any")  # each cross-section needs both the factor and the forward return
if df.empty:
return pd.Series(dtype=float)
        def rank_ic_for_date(g):
            # Spearman correlation as the Pearson correlation of the ranks
            return g["factor"].rank().corr(g["fut_ret"].rank())
        ic_series = df.groupby(level="date").apply(lambda g: rank_ic_for_date(g.droplevel(0)))
ic_series.index = pd.to_datetime(ic_series.index)
return ic_series
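    # Sketch (hedged): the forward return must be aligned to the signal date,
    # e.g. a 5-day forward return built from the same close prices:
    #   close = px["close"]
    #   fut_ret = (close.groupby(level="instrument").pct_change(5)
    #                   .groupby(level="instrument").shift(-5))
    #   ic = proc.calculate_daily_rank_ic(composite, fut_ret)
    #   ic.mean(), ic.std()  # common summary: mean IC, and ICIR = mean / std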
def quantile_performance(
self,
factor: pd.Series,
forward_returns: pd.Series,
quantiles: int = 5
) -> pd.DataFrame:
"""
分层(Quantile)回测:按 date 划分截面 quantiles,然后计算每层未来收益均值及样本数
返回 DataFrame: index=date, columns = Q1_mean, Q1_count, Q2_mean, ...
"""
if factor.empty or forward_returns.empty:
return pd.DataFrame()
df = pd.concat([factor.rename("factor"), forward_returns.rename("fut_ret")], axis=1)
        # Bucket each date's cross-section
results = []
for date, grp in df.groupby(level="date"):
data = grp.droplevel(0).dropna()
if data.empty or len(data) < quantiles:
continue
            labels = list(range(1, quantiles + 1))
            try:
                data["q"] = pd.qcut(data["factor"], q=quantiles, labels=labels, duplicates="drop")
            except Exception:
                # qcut can fail when duplicate bin edges are dropped (the labels
                # no longer match the surviving bins); fall back to rank-based binning
                data["q"] = pd.cut(data["factor"].rank(method="first"), bins=quantiles, labels=labels)
            # Mean forward return and sample count per bucket
stats = {}
for q in labels:
sel = data["q"] == q
stats[f"Q{q}_mean"] = data.loc[sel, "fut_ret"].mean()
stats[f"Q{q}_count"] = sel.sum()
stats["date"] = pd.to_datetime(date)
results.append(stats)
if not results:
return pd.DataFrame()
res_df = pd.DataFrame(results).set_index("date").sort_index()
return res_df
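    # Sketch (hedged):
    #   qp = proc.quantile_performance(composite, fut_ret, quantiles=5)
    #   (qp["Q5_mean"] - qp["Q1_mean"]).mean()  # average daily top-minus-bottom spread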
def detect_abnormal_fluctuations(
self,
price_df: pd.DataFrame,
        threshold: float = 0.1
    ) -> pd.DataFrame:
        """
        Flag abnormal moves: rows whose absolute one-day close-to-close return
        exceeds `threshold`.

        NOTE: the original source is truncated at this method; the default
        threshold and the body below are a minimal reconstruction inferred from
        the method name and signature.
        """
        if price_df.empty or "close" not in price_df.columns:
            return pd.DataFrame()
        ret = price_df["close"].groupby(level="instrument").pct_change()
        mask = ret.abs() > threshold
        abnormal = price_df.loc[mask].copy()
        abnormal["daily_return"] = ret[mask]
        return abnormal
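# Minimal end-to-end sketch (hedged): assumes a configured DAI environment; the
# instrument codes and date range are hypothetical placeholders.
if __name__ == "__main__":
    proc = StockFactorProcessorDAI()
    px = proc.fetch_price_data(["000001.SZ", "600000.SH"], "2023-01-01", "2023-12-31")
    if not px.empty:
        factors = proc.compute_base_factors(px)
        std_factors = proc.standardize_factors(factors, mode="cross")
        composite = proc.build_composite_factor(std_factors)
        # 5-day forward return aligned to the signal date (see the IC sketch above)
        close = px["close"]
        fut_ret = (close.groupby(level="instrument").pct_change(5)
                        .groupby(level="instrument").shift(-5))
        print("mean rank IC:", proc.calculate_daily_rank_ic(composite, fut_ret).mean())
        print(proc.quantile_performance(composite, fut_ret, quantiles=5).tail())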