pandas性能百倍提升之用字典索引或ndarray替换DataFrame索引以及内存占用分析

探讨Pandas在大数据量下的性能瓶颈及优化策略,对比DataFrame、NumPy数组与字典的效率和内存占用。

       在利用pandas进行数据分析时,DataFrame是其基本的数据结构,当数据量较小时还好,一旦数据量较大,比如几十万上百万时,这时DataFrame就会变得笨重,笨重主要体现在对其索引的操作上,而对DataFrame的索引操作又是基本的操作,所以这时,在性能上就会有很大的损失;对pandas的使用可以让我们可以直观简单的进行数据分析,但是往往会在性能上有较大的损失。当然,对于性能的损失不能一概而论,pandas具有很多的内置函数,这些函数的底层很多是c语言的python封装,如果我们可以灵活的使用这些内置函数,那么性能上的损失其实是很小的,甚至会提升性能,只是有时候在一些比较复杂的操作中,并没有对应的内置函数供我们直接调用,这时,就需要从其他的角度去考虑提升性能。

性能对比和分析

       在数据分析中,往往最消耗时间的是在loop上,一旦数据量较大,比如一个50万行的DataFrame,如果我们需要进行逐行的loop,那么时间的消耗就相当于是一个loop的50万倍,这种量级的放大是很恐怖的,所以,当我们要不可避免的使用loop的时候,就需要非常的谨慎,尽量的减少一个loop需要消耗的时间。

       在pandas中,一个loop中比较消耗时间的地方往往有两方面:一是对很少的数据量用pandas的内置函数,起步时间消耗过多,对此可以转而用python原生的方式实现,具体可以看笔者的这篇文章;二是对DataFrame的索引操作上。本文就是解决第二个问题带来的时间消耗。

       pandas中对于DataFrame的索引操作是相对低效的,所以我们不应该在一个几十万的循环中使用DataFrame的索引操作,否则会造成程序效率极其低下。为了保持DataFrame的这种数据的相对结构,我们可以有两种方式去替换DataFrame的频繁的索引操作:一,替换成numpy的ndarray;二,替换成python的字典数据结构。下面我们通过一个简单的例子来对比下这三种方式的效率,如下所示。

import time
import pandas as pd
import numpy as np

df=pd.DataFrame(np.arange(800000).reshape(200000,4),columns=list('abcd'))
t1=time.time()
s1=df.apply(lambda x:x['a']+x['d'],axis=1)
t2=time.time()
print(t2-t1)

arr=df.values
t3=time.time()
s=np.apply_along_axis(lambda x:x[0]+x[3],axis=1,arr=arr)
t4=time.time()
print(t4-t3)

dic=df.to_dict()
l=[]
t5=time.time()
for i in range(len(df)):
    l.append(dic['a'][i]+dic['d'][i])
t6=time.time()
print(t6-t5)

# output: 10.034282922744751
#         1.6413774490356445
#         0.11075925827026367

       上述例子仅仅只是一个例子,只是为了传达本文的思想而已,因为实际上可以完全可以通过df[['a','d']].sum(axis=1)函数来快速实现。从例子中我们可以看到,三种方式的运行效率相差很大。首先第一种方式,我们是直接利用DataFrame的apply方法实现逐行的loop,对于每一个loop,通过直接对Series的索引来实现两列的相加;第二种方式,我们是先将DatFrame转为ndarray,然后在numpy中采取类似的做法;第三种方式,我们先将DataFrame转为python的字典对象,然后通过for loop实现,内一个loop中直接通过字典索引实现两列的相加。

       通过对比,第一种和第二种方式之间,后者相当于把DataFrame转为numpy的ndarray再进行处理,可知numpy的ndarray的效率更高,虽然pandas也是基于numpy的,但是由于pandas进行了进一步的封装,所以效率自然更低,因此,我们总是可以用numpy来处理比较耗时的pandas任务,特别是在数据量很大的时候,且无法通过内置函数直接办到的任务,那么numpy提升的效率可以很显著。我们再看第一种和第三种方式的对比,后者是先把DataFrame先转为字典,然后通过for loop实现,这里两者的区别还在于前者是apply,后者是for loop,但其实apply和for loop在非内置函数的简单操作上的效率是差不多的,甚至apply会更慢些,但是在pandas内置函数的操作上apply会快很多,具体可看笔者的这篇文章;所以这里两者的差距我们可以认为是由索引造成的,明显的,字典索引相对于DataFrame的索引,前者的效率会高很多,在python中,字典索引几乎是最为高效的索引方式了,因此,我们把DataFrame转为字典并对字典进行操作,这种方式提升的效率效果是最好的,性能几乎提升近百倍!

内存占用分析

        上面只是单纯的从性能上进行分析对比,下面还要对比一下这三种方式在内存占用上的区别。首先,对于DataFrame和ndarray,由于前者是基于后者的,因此两者的内存占用其实是产不多的,而且由于前者对后者进行了封装,所以严格来讲,DataFrame的内存占用会大一些,ndarray的内存占用会小一些,但是差别不大。由于ndarray是对内存结构进行了优化的,所以相比于python的字典对象,储存相同的信息,字典对象的内存占用会大很多,当然,这并不是说就是字典对象的缺点,因为字典对象可以存储不同类型的对象,而ndarray只可以存储同一类型的元素,因此,不同的设计方式使得两者在内存占用上不太一样,各有优劣。但是当我们通过把DataFrame转为dict时,相比于DataFrame,dict可储存多类型对象的优势不再,这时dict对象显著变大了,这就是其一个缺点了。具体的可看如下结果。

import sys

size=total_size(dic)
print(sys.getsizeof(df))
print(arr.nbytes)
print(size)

# output: 3200104
#         3200000
#         86715124

       可以看到,df占据了3200104字节,arr占据了3200000字节,而dic则占据了86715124字节。这里对于arr和dic不能直接用sys.getsizeof获取其实际内存,因为sys.getsizeof只能获取对象的本身的内存,如果对象是个容器,容器内部的内容内存占用是无法获取到的;字典实际上是个容器,而ndarray又有自己的内存设计,可以通过其nbytes属性获取,而df则兼容了sys.getsizeof这个接口。所以这里对于dic,笔者用的是total_size这个函数,关于该函数的定义,可以查看笔者的这篇文章

       Anyway,最后我们看到的是,字典占据的内存暴涨,相当于是df和ndarray的二十多倍。虽然字典在访问速度上极快,但是内存占用也是极高的,典型的用空间换时间。所以,当我们考虑性能的时候,也要兼考虑内存占用,特别是在大数据量的情况下,如果我们盲目的将df转为字典,那么可能内存就承受不住了。所以如果内存充分的足够,那么可以转为字典,但是如果内存并不是那么充裕,那么我们可以采用转为ndarray的方式去提高性能,因为ndarray的内存占用是三者里最佳的,而且在性能上也是很不错的,因此,在大数据量的情形下,考虑到内存占用,ndarray往往是一个更佳的选择!

# stock_factor_processing_dai.py import dai import pandas as pd import numpy as np from typing import List, Dict, Optional, Tuple class StockFactorProcessorDAI: """ 基于DAI的数据拉取与因子处理类(修正版) 说明: - 使用 dai.query(sql, filters=...) 拉取分区表数据,推荐通过 filters 传入 date/instrument 分区条件以加速查询 - DataFrame 使用 MultiIndex (date, instrument) """ def __init__(self): pass def fetch_price_data( self, instruments: List[str], start_date: str, end_date: str, fields: List[str] = ["open", "high", "low", "close", "volume", "amount"] ) -> pd.DataFrame: """ 使用 dai.query 获取日线级别行情 (示例使用 cn_stock_bar1d) 注意: - 不要在 SQL 里重复对 date 做过滤。如果使用 filters 参数,请把 date 放到 filters 里(避免在 SQL 中再写 date 条件)。 - filters 会用于查询里的所有表,能显著加速分区表读取。 返回: - pandas.DataFrame, index 为 ['date','instrument'] 的 MultiIndex """ # 安全检查 if not instruments: raise ValueError("instruments 列表不能为空") # 将 instrument 列表格式化为 SQL IN 子句(如果需要先做小量过滤) # 这里我们仍然使用 filters 来加速按 date 分区 instruments_sql = ",".join([f"'{s}'" for s in instruments]) sql = f""" SELECT date, instrument, {', '.join(fields)} FROM cn_stock_bar1d WHERE instrument IN ({instruments_sql}) """ filters = {"date": [start_date, end_date]} try: res = dai.query(sql, filters=filters).df() except Exception as e: print("DAI 查询失败:", e) return pd.DataFrame() if res.empty: return res # 确保 date 为 datetime if not pd.api.types.is_datetime64_any_dtype(res["date"]): res["date"] = pd.to_datetime(res["date"]) # 设置 MultiIndex 并排序 res = res.set_index(["date", "instrument"]).sort_index() return res def build_full_index( self, raw_df: pd.DataFrame, start_date: Optional[str] = None, end_date: Optional[str] = None ) -> pd.DataFrame: """ 构建连续日期 x instruments 的 MultiIndex 并重新索引原始数据以填补缺失 """ if raw_df.empty: return raw_df # 得到 instruments 日期范围 instruments = raw_df.index.get_level_values("instrument").unique() if start_date is None: start_date = raw_df.index.get_level_values("date").min() if end_date is None: end_date = raw_df.index.get_level_values("date").max() date_range = pd.date_range(start=start_date, end=end_date, freq="D") full_index = pd.MultiIndex.from_product([date_range, instruments], names=["date", "instrument"]) full_df = raw_df.reindex(full_index) return full_df def compute_base_factors(self, price_df: pd.DataFrame) -> Dict[str, pd.Series]: """ 计算基础因子(示例) 输入 price_df 为 MultiIndex(date, instrument) 的 DataFrame, 包含 close, volume, amount 等列 返回 dict: {factor_name: pd.Series (同 price_df.index)} """ if price_df.empty: return {} # 使用 groupby(level='instrument') 做按票的时序计算 close = price_df["close"] vol = price_df["volume"] amt = price_df.get("amount", pd.Series(index=close.index, dtype=float)) # 20日动量 (close / close.shift(20) - 1) momentum = close.groupby(level="instrument").pct_change(periods=20) # 简化价值因子:close / amount(注意 amount 为成交额,通常需要市值等数据更合理) value = close / amt.replace({0: np.nan}) # 波动率:近20日收益率的样本标准差 ret = close.groupby(level="instrument").pct_change() volatility = ret.groupby(level="instrument").rolling(20).std().droplevel(0) # 流动性:当日成交量 / 过去20日平均成交量 avg_vol_20 = vol.groupby(level="instrument").rolling(20).mean().droplevel(0) liquidity = vol / avg_vol_20 factors = { "momentum": momentum, "value": value, "volatility": volatility, "liquidity": liquidity } return factors def standardize_factors( self, factors: Dict[str, pd.Series], mode: str = "cross" # 'cross' 按日期截面标准化; 'global' 全样本 z-score ) -> Dict[str, pd.Series]: """ 因子标准化与缩尾 mode: - 'cross' : 在每个 date 的截面上做 z-score 标准化(通常用于因子在截面上比较) - 'global': 全样本 z-score 返回 standardized_factors dict """ std_factors = {} for name, s in factors.items(): if s is None or s.empty: std_factors[name] = s continue if mode == "cross": # 对每个 date 做截面 z-score def zscore_group(x): m = x.mean() sd = x.std() if pd.isna(sd) or sd == 0: return x * 0.0 return (x - m) / sd # groupby level date z = s.groupby(level="date").apply(lambda g: zscore_group(g.droplevel(0))) # groupby.apply 会把 index 变成 (date, instrument) 两级, 恢复 index z.index = s.index # 对齐回原索引 standardized = z else: # global z-score mean = s.mean() std = s.std() standardized = (s - mean) / std if std and std > 0 else s * 0.0 # 缩尾(clip): 1% - 99% lower = standardized.groupby(level="date").quantile(0.01).reindex(standardized.index, level=0) upper = standardized.groupby(level="date").quantile(0.99).reindex(standardized.index, level=0) # 对齐后 clip standardized = standardized.clip(lower=lower, upper=upper) std_factors[name] = standardized return std_factors def build_composite_factor( self, standardized_factors: Dict[str, pd.Series], weights: Optional[Dict[str, float]] = None ) -> pd.Series: """ 复合因子 =标准化因子加权 如果 weights 未提供,采用等权 """ if not standardized_factors: return pd.Series(dtype=float) names = list(standardized_factors.keys()) if weights is None: w = {n: 1.0 / len(names) for n in names} else: w = weights # 先取索引(各因子索引应当一致) idx = standardized_factors[names[0]].index composite = pd.Series(0.0, index=idx) for n in names: factor = standardized_factors[n].reindex(idx) weight = w.get(n, 0.0) composite = composite.add(factor.fillna(0.0) * weight, fill_value=0.0) return composite def calculate_daily_rank_ic( self, factor: pd.Series, forward_returns: pd.Series ) -> pd.Series: """ 计算每日 Rank-IC(按日期截面): - 对每个 date,计算 factor(rank) 与 forward_returns(rank) 的 Spearman 相关系数 返回: - pd.Series, index=date, 每日 rank_ic """ # 合并并按 date 分组 if factor.empty or forward_returns.empty: return pd.Series(dtype=float) df = pd.concat([factor.rename("factor"), forward_returns.rename("fut_ret")], axis=1) df = df.dropna(how="any") # 同一截面上必须同时有因子未来收益 if df.empty: return pd.Series(dtype=float) def rank_ic_for_date(g): # 以 rank 计算相关性(Spearman) return g["factor"].rank().corr(g["fut_ret"].rank()) ic_series = df.groupby(level="date").apply(lambda g: rank_ic_for_date(g.droplevel(0))) ic_series.index = pd.to_datetime(ic_series.index) return ic_series def quantile_performance( self, factor: pd.Series, forward_returns: pd.Series, quantiles: int = 5 ) -> pd.DataFrame: """ 分层(Quantile)回测:按 date 划分截面 quantiles,然后计算每层未来收益均值及样本数 返回 DataFrame: index=date, columns = Q1_mean, Q1_count, Q2_mean, ... """ if factor.empty or forward_returns.empty: return pd.DataFrame() df = pd.concat([factor.rename("factor"), forward_returns.rename("fut_ret")], axis=1) # 每日分层 results = [] for date, grp in df.groupby(level="date"): data = grp.droplevel(0).dropna() if data.empty or len(data) < quantiles: continue try: labels = range(1, quantiles + 1) data["q"] = pd.qcut(data["factor"], q=quantiles, labels=labels, duplicates="drop") except Exception: # qcut 可能因重复值失败,退化为 rank-based binning data["q"] = pd.cut(data["factor"].rank(method="first"), bins=quantiles, labels=labels) # 计算每组均值样本数 stats = {} for q in labels: sel = data["q"] == q stats[f"Q{q}_mean"] = data.loc[sel, "fut_ret"].mean() stats[f"Q{q}_count"] = sel.sum() stats["date"] = pd.to_datetime(date) results.append(stats) if not results: return pd.DataFrame() res_df = pd.DataFrame(results).set_index("date").sort_index() return res_df def detect_abnormal_fluctuations( self, price_df: pd.DataFrame, threshold: float =
最新发布
11-01
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值