torch.quantile or np.quantile的计算

本文介绍如何使用PyTorch计算张量的分位数,包括当分位数位置为非整数时的插值计算方法。通过具体示例展示了不同情况下的计算过程。
部署运行你感兴趣的模型镜像

torch的文档详细说明了quantile的计算方法。

主要是将q的范围[0, 1] 转成输入index的范围[0, n]。 也就是说,将q 乘 n。然后插值计算。

quantile位置不是整数

a = torch.tensor([0.0, 1.1, 2.1, 3.1])
q_result = torch.quantile(a, torch.tensor([0.1, 0.9]))
# tensor([0.3300, 2.8000])

a的index范围: [0, 3]
两个q值0.1, 0.9 都乘以3之后得到0.3, 2.7。

  • 0.3在0和1之间,需要插值。缺省插值方法为a + (b-a) * fraction。
0.0+ (1.1 - 0.0)*0.3 = 0.3300
  • 2.7在2和3之间
2.1 + (3.1 - 2.1 ) * 0.7 = 2.8000

quantile位置为整数

不需要插值。
例如

a = torch.tensor([0.1, 0.2, 0.3, 0.33, 0.9])
q2 = torch.quantile(a, torch.tensor([0.5]))
# 0.3 

0.5 * 4 = 2, 因此取a[2]

您可能感兴趣的与本文相关的镜像

PyTorch 2.8

PyTorch 2.8

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

帮我看下这个框架对不对,让后如果要用lstm和xgboot混合模型来预测可以嘛?可以话怎么弄,另外加入jvquant的l2数据获取--- 1️⃣ 先把数据源跑通(最优先) 任务 目的 现成资源 ✅ 申请 MiniQMT 实盘权限 拿到真实行情 + 交易接口 找券商(国金、华鑫、中泰等)开通 QMT 实盘,勾选“量化接口” ✅ 用 xtquant 取到任意 1 只股票 1 天 1 分钟线 验证数据链路 `pip install xtquant` → 跑官方 demo:`from xtquant import xtdata; xtdata.download_history_data('000001.SZ', period='1m', start_time='20250724')` ✅ 用 akshare 补股东户数 弥补 MiniQMT 没有基本面 `pip install akshare` → `ak.stock_zh_a_gdhs(symbol="300539")` --- 2️⃣ 把“特征函数”写完整(第二步) 你现在缺的 3 个核心函数,我都给你补成可运行的雏形,先跑通再优化: ```python # file: zhuang_features.py from xtquant import xtdata import akshare as ak import pandas as pd import numpy as np def get_chip_density(code, end_date, n=90): """筹码集中度:90%成本区间 / 中位价""" bars = xtdata.get_market_data([code], end_time=end_date, count=n, period='1d')[code] closes = bars['close'].dropna() if len(closes) < n*0.8: return np.nan p90 = closes.quantile(0.95) # 90%分位 p10 = closes.quantile(0.05) p50 = closes.median() return (p90 - p10) / p50 def get_holder_change(code): """最近一期股东户数环比变化率""" try: df = ak.stock_zh_a_gdhs(symbol=code.replace('.SZ', '').replace('.SH', '')) # 取最近两期 latest = df.iloc[0]['股东户数'] prev = df.iloc[1]['股东户数'] return (latest - prev) / prev except: return np.nan def get_breakout(code, end_date): """最近是否缩量突破 20 日新高""" bars = xtdata.get_market_data([code], end_time=end_date, count=20, period='1d')[code] vol = bars['volume'] price = bars['close'] vol_ratio = vol.iloc[-1] / vol.mean() is_high = price.iloc[-1] == price.max() return 1 if (vol_ratio < 0.8 and is_high) else 0 ``` --- 3️⃣ 把“训练样本”做出来(第三步) 任务 具体动作 ✅ 标注庄股 用 Excel 把历史上 20 只公认庄股(如 300539、002995、603123 等)在启动日打标签 1,其余随机 200 只正常股打 0 ✅ 批量跑特征 用上面 3 个函数,把每只股票在启动日前 1 天的特征都跑成 CSV,格式:`code,chip_density,holder_change,breakout,label` ✅ 存到 `dataset.csv` 后面 PyTorch 直接 `pd.read_csv('dataset.csv')` --- 4️⃣ 把“模型”写成可训练脚本(第四步) ```python # file: train_model.py import torch, pandas as pd from sklearn.model_selection import train_test_split df = pd.read_csv('dataset.csv').dropna() X = torch.tensor(df[['chip_density','holder_change','breakout']].values, dtype=torch.float32) y = torch.tensor(df['label'].values, dtype=torch.float32).unsqueeze(1) X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2) class ZhuangNet(torch.nn.Module): def __init__(self): super().__init__() self.net = torch.nn.Sequential( torch.nn.Linear(3, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1), torch.nn.Sigmoid() ) def forward(self, x): return self.net(x) model = ZhuangNet() opt = torch.optim.Adam(model.parameters(), 1e-3) criterion = torch.nn.BCELoss() for epoch in range(200): opt.zero_grad() loss = criterion(model(X_train), y_train) loss.backward() opt.step() if epoch % 20 == 0: with torch.no_grad(): val_loss = criterion(model(X_val), y_val) print(epoch, loss.item(), val_loss.item()) torch.save(model.state_dict(), 'zhuang_net.pt') ``` --- 5️⃣ 把“实时扫描”写成脚本(第五步) ```python # file: realtime_scan.py import torch, datetime, json, requests from zhuang_features import get_chip_density, get_holder_change, get_breakout model = ZhuangNet() model.load_state_dict(torch.load('zhuang_net.pt')) model.eval() codes = xtdata.get_stock_list_in_sector('沪深A股') today = datetime.datetime.today().strftime('%Y%m%d') found = [] for code in codes: try: f1 = get_chip_density(code, today) f2 = get_holder_change(code) f3 = get_breakout(code, today) if any(pd.isna([f1,f2,f3])): continue prob = model(torch.tensor([[f1,f2,f3]])).item() if prob > 0.8: found.append((code, round(prob,3))) except Exception as e: print(code, e) # 推送到企业微信机器人 webhook = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=你的key' data = {"msgtype":"text","text":{"content": f"发现庄股候选:{found}"}} requests.post(webhook, json=data) ``` --- 6️⃣ 把“实盘执行”加一道保险(第六步) - ✅ 先用模拟盘跑 2 周,确认信号无未来函数(比如把 `end_date` 设成昨日,再对比今天的实际走势)。 - ✅ 用 `schedule` 或 `crontab` 每天 14:50 自动跑一次,盘中不干扰。 - ✅ 加熔断:单只股票当日成交额 < 5000 万直接跳过,避免流动性陷阱。 --- 7️⃣ 把“日志 & 回测”补齐(最后一步) - ✅ 把每日扫描结果写进 `sqlite` 或 `csv`,方便后续回测胜率。 - ✅ 用 `backtrader` 对 2023-2024 全市场跑一遍,看信号后 5 日收益分布,确定最终阈值(可能 0.7 比 0.8 更好)。 --- 🚩一句话总结 你现在缺的只是“把骨架连上血肉的七步流程”。 按上面 1→7 的顺序逐条完成: 申请接口 → 补特征 → 造样本 → 训练 → 实时扫描 → 模拟盘 → 回测, 最终就能在 MiniQMT 上全自动跑起来。
07-27
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值