文章目录
- Toolkit 作用
- Toolkit 逐函数解析
- 1. 获取默认配置
- 2. update_config
- 3. config
- 4. `__init__`
- 5. get_reddit_news
- 6. get_finnhub_news
- 7. get_reddit_stock_info
- 8. get_chinese_social_sentiment
- 9. get_finnhub_company_insider_sentiment
- 10. get_YFin_data
- 11. get_YFin_data_online
- 12. get_stockstats_indicators_report
- 13. get_stockstats_indicators_report_online
- 14. get_simfin_balance_sheet 资产负债表
- 15. get_simfin_income_statement利润表
- 16. get_simfin_cashflow_statement 现金流量表
- 17. get_simfin_ratios 财务指标比率
- 18. get_simfin_company_info 公司基本信息
- 19. get_simfin_shareprices 历史股价数据
- 20. get_fundamentals_openai 财务+估值
- 21. get_china_fundamentals A 股财务数据
- 22. get_stock_fundamentals_unified 整合财务+估值
- create_msg_delete函数(非Toolkit类)
- ChromaDBManager
- FinancialSituationMemory类
Toolkit 作用
-
Toolkit(tradingagents/agents/utils/agent_utils.py)是 TradingAgents 框架里的工具集,封装了各种外部数据源的调用接口(新闻、财务、行情、情绪、指标等),并通过@tool装饰器暴露给 LLM 使用。 -
TradingAgents Toolkit 功能总览表
| 模块类别 | 主要函数/工具 | 功能描述 | 数据来源 | 状态 |
|---|---|---|---|---|
| 市场行情 (Market) | get_market_data_unified | 获取股票/指数的市场行情数据(价格、成交量、技术指标) | Yahoo Finance, Tushare, 东方财富等 | |
get_price_history | 获取历史 K 线数据 | 同上 | ||
get_technical_indicators | 生成技术指标(MA, RSI, MACD 等) | 本地计算 | ||
| 新闻 (News) | get_news_articles | 抓取金融新闻 | Google News, 东方财富新闻等 | |
summarize_news | 对新闻做摘要 | LLM 处理 | ||
| 社交媒体 (Social) | get_social_sentiment | 获取社交媒体情绪(Twitter, 微博) | API / 爬虫 | 数据源可能受限 |
analyze_sentiment | 使用 LLM 对评论、帖子做情绪分析 | LLM | ||
| 基本面 (Fundamentals) | get_stock_fundamentals_unified | 统一接口:获取美股/A股/港股的财务数据与估值 | Yahoo Finance, SimFin, Tushare, AKShare, 东方财富 | |
get_fundamentals_openai | 使用 OpenAI Agent 调用财务数据 | OpenAI Agent | 废弃 | |
get_china_fundamentals | 获取中国 A 股财务数据(旧接口) | Tushare, AKShare | 废弃 | |
| 风险分析 (Risk) | get_risk_metrics | 计算风险指标(波动率、夏普比率、回撤) | 本地计算 | |
stress_test | 压力测试(不同市场情景下的资产表现) | 本地模拟 | ||
portfolio_risk_analysis | 投资组合风险分析 | 本地计算 + 历史行情 |
Toolkit 逐函数解析
- 大部分函数都是通过
tradingagents.dataflows.interface获取数据。
1. 获取默认配置
from tradingagents.default_config import DEFAULT_CONFIG
_config = DEFAULT_CONFIG.copy()
- 获取默认配置,默认配置如下:
import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
# Note: Database and cache configuration is now managed by .env file and config.database_manager
# No database/cache settings in default config to avoid configuration conflicts
}
2. update_config
@classmethod
def update_config(cls, config):
"""Update the class-level configuration."""
cls._config.update(config)
- 更新
Toolkit的全局配置(类级别),比如数据源 API key、默认参数等。
3. config
@property
def config(self):
"""Access the configuration."""
return self._config
- 返回当前配置。
4. __init__
def __init__(self, config=None):
if config:
self.update_config(config)
- 构造函数,可传入配置并更新默认配置。
5. get_reddit_news
@staticmethod
@tool
def get_reddit_news(
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
- 从 Reddit 获取某天起过去 7 天内的 全球新闻(最多 5 条),主要用于宏观舆情分析。
6. get_finnhub_news
@staticmethod
@tool
def get_finnhub_news(
ticker: Annotated[str, "Search query of a company, e.g. 'AAPL, TSM, etc."],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
"""
Retrieve the latest news about a given stock from Finnhub within a date range
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing news about the company within the date range from start_date to end_date
"""
end_date_str = end_date
end_date = datetime.strptime(end_date, "%Y-%m-%d")
start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days
finnhub_news_result = interface.get_finnhub_news(
ticker, end_date_str, look_back_days
)
return finnhub_news_result
- 调用 Finnhub API 获取指定股票在
start_date ~ end_date的新闻。
7. get_reddit_stock_info
@staticmethod
@tool
def get_reddit_stock_info(
ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],
curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
"""
Retrieve the latest news about a given stock from Reddit, given the current date.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A formatted dataframe containing the latest news about the company on the given date
"""
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return stock_news_results
- 从 Reddit 获取某只股票的近期新闻和讨论。
8. get_chinese_social_sentiment
@staticmethod
@tool
def get_chinese_social_sentiment(
ticker: Annotated[str, "股票代码,例如 '600519.SS' 或 '000001.SZ'"],
curr_date: Annotated[str, "要获取的日期,yyyy-mm-dd 格式"],
) -> str:
"""
获取某只股票在中国社交媒体上的情绪分析。
Args:
ticker (str): 股票代码 (A股格式)
curr_date (str): 要获取的日期
Returns:
str: 包含中国社交媒体情绪分析的格式化 dataframe
"""
try:
sentiment_result = interface.get_chinese_sentiment(ticker, curr_date, 7)
except Exception as e:
logger.warning(f"中国舆情数据获取失败,回退到 Reddit: {e}")
sentiment_result = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
return sentiment_result
- 获取 A 股股票在 中国本土社交媒体 上的舆情/情绪数据。
- 如果中国数据源失败 → 自动回退到 Reddit。
- 常用于国内公司投资情绪分析。
9. get_finnhub_company_insider_sentiment
@staticmethod
@tool
def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "股票代码,例如 'AAPL', 'TSM'"],
) -> str:
"""
获取公司内部人买卖行为的情绪数据。
Args:
ticker (str): 公司代码
Returns:
str: 格式化的内部人情绪数据
"""
insider_sentiment = interface.get_finnhub_insider_sentiment(ticker)
return insider_sentiment
- 调用 Finnhub API 获取 内部人买卖股票的情绪指标(比如 CEO、CFO 买入/卖出, 属于“聪明钱”指标)。
10. get_YFin_data
@staticmethod
@tool
def get_YFin_data(
ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],
period: Annotated[str, "数据周期,例如 '1mo', '6mo', '1y'"],
) -> str:
"""
获取 Yahoo Finance 历史行情数据。
Args:
ticker (str): 股票代码
period (str): 时间范围
Returns:
str: 格式化的历史行情 dataframe
"""
df = interface.get_yahoo_finance_history(ticker, period)
return df.to_string()
- 从 Yahoo Finance 获取某只股票的历史行情(K 线数据)。
11. get_YFin_data_online
@staticmethod
@tool
def get_YFin_data_online(
ticker: Annotated[str, "公司股票代码"],
start_date: Annotated[str, "开始日期 yyyy-mm-dd"],
end_date: Annotated[str, "结束日期 yyyy-mm-dd"],
) -> str:
"""
获取 Yahoo Finance 区间行情。
"""
df = interface.get_yahoo_finance_range(ticker, start_date, end_date)
return df.to_string()
- 类似
get_YFin_data,但支持 指定开始和结束日期。
12. get_stockstats_indicators_report
@staticmethod
@tool
def get_stockstats_indicators_report(
ticker: Annotated[str, "公司股票代码"],
period: Annotated[str, "分析周期,例如 '6mo'"],
) -> str:
"""
获取技术指标分析报告(基于 stockstats)。
"""
df = interface.get_stockstats_indicators(ticker, period)
return df.to_string()
- 基于 stockstats 库计算技术指标(均线、RSI、MACD 等)。
13. get_stockstats_indicators_report_online
@staticmethod
@tool
def get_stockstats_indicators_report_online(
ticker: Annotated[str, "公司股票代码"],
start_date: Annotated[str, "开始日期"],
end_date: Annotated[str, "结束日期"],
) -> str:
"""
获取指定区间的技术指标分析。
"""
df = interface.get_stockstats_indicators_range(ticker, start_date, end_date)
return df.to_string()
get_stockstats_indicators_report类似,但支持 自定义区间。
14. get_simfin_balance_sheet 资产负债表
@staticmethod
@tool
def get_simfin_balance_sheet(
ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],
report_type: Annotated[str, "报告类型,例如 'annual', 'quarterly'"],
) -> str:
"""
获取公司资产负债表 (Balance Sheet)。
"""
df = interface.get_simfin_balance_sheet(ticker, report_type)
return df.to_string()
- 从 SimFin API 获取 资产负债表。
- 可选年度 (annual) 或季度 (quarterly)。
- 返回格式化的 dataframe。
15. get_simfin_income_statement利润表
@staticmethod
@tool
def get_simfin_income_statement(
ticker: Annotated[str, "公司股票代码"],
report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
"""
获取公司利润表 (Income Statement)。
"""
df = interface.get_simfin_income_statement(ticker, report_type)
return df.to_string()
- 获取 利润表(营业收入、净利润、毛利率等)。
- 主要用于盈利能力分析。
16. get_simfin_cashflow_statement 现金流量表
@staticmethod
@tool
def get_simfin_cashflow_statement(
ticker: Annotated[str, "公司股票代码"],
report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
"""
获取公司现金流量表 (Cashflow Statement)。
"""
df = interface.get_simfin_cashflow_statement(ticker, report_type)
return df.to_string()
- 获取 现金流量表(经营活动、投资活动、融资活动现金流)。
- 用于衡量公司“造血能力”和资金链稳定性。
17. get_simfin_ratios 财务指标比率
@staticmethod
@tool
def get_simfin_ratios(
ticker: Annotated[str, "公司股票代码"],
report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
"""
获取公司财务比率 (Ratios),例如 PE、ROE、负债率。
"""
df = interface.get_simfin_ratios(ticker, report_type)
return df.to_string()
- 获取 财务指标比率(PE, PB, ROE, 负债率, 流动比率)。
- 适合做跨公司对比。
18. get_simfin_company_info 公司基本信息
@staticmethod
@tool
def get_simfin_company_info(
ticker: Annotated[str, "公司股票代码"]
) -> str:
"""
获取公司基本信息(行业、地区、规模等)。
"""
df = interface.get_simfin_company_info(ticker)
return df.to_string()
- 获取公司的 基本信息(行业分类、上市地、公司规模)。
- 在做行业对比或聚类分析时很有用。
19. get_simfin_shareprices 历史股价数据
@staticmethod
@tool
def get_simfin_shareprices(
ticker: Annotated[str, "公司股票代码"],
start_date: Annotated[str, "开始日期"],
end_date: Annotated[str, "结束日期"],
) -> str:
"""
获取公司历史股价 (Share Prices)。
"""
df = interface.get_simfin_shareprices(ticker, start_date, end_date)
return df.to_string()
- 获取 SimFin 的历史股价数据。
- 类似 Yahoo Finance,但数据源不同。
20. get_fundamentals_openai 财务+估值
@staticmethod
@tool
def get_fundamentals_openai(
ticker: Annotated[str, "公司代码,例如 AAPL, TSLA"],
report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:
"""
使用 OpenAI Agent 获取公司财务和估值信息。
(已废弃,推荐使用 get_stock_fundamentals_unified)
"""
logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")
return interface.get_fundamentals_openai(ticker, report_type)
- 原始版本,用 OpenAI Agent 来获取财务+估值。
- 已经 废弃,现在统一整合进
get_stock_fundamentals_unified。
- 已经 废弃,现在统一整合进
21. get_china_fundamentals A 股财务数据
@staticmethod
@tool
def get_china_fundamentals(
ticker: Annotated[str, "中国股票代码,例如 600519"],
report_type: Annotated[str, "年度/季度"] = "annual",
) -> str:
"""
获取中国 A 股财务数据(通过 Tushare 或 AKShare)。
(已废弃,推荐使用 get_stock_fundamentals_unified)
"""
logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")
return interface.get_china_fundamentals(ticker, report_type)
- 早期用于获取 A 股财务数据。
- 数据源:Tushare / AKShare。
- 已 废弃,功能已被统一接口替代。
22. get_stock_fundamentals_unified 整合财务+估值
@staticmethod
@tool
def get_stock_fundamentals_unified(
ticker: Annotated[str, "股票代码,例如 AAPL, 600519, 00700.HK"],
market: Annotated[str, "市场类型:us / cn / hk"] = "us",
report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:
"""
统一接口:自动识别市场并获取财务数据与估值。
"""
return interface.get_stock_fundamentals_unified(ticker, market, report_type)
- 核心统一入口,整合了所有财务+估值工具:
- 美股 → Yahoo Finance + SimFin
- A 股 → Tushare + AKShare
- 港股 → Yahoo Finance + 东方财富
- 自动根据
market参数(us/cn/hk)选择合适数据源。 - 返回内容包括:
- 财报(资产负债表、利润表、现金流)
- 估值指标(PE, PB, PEG, ROE 等)
- 行业对比
create_msg_delete函数(非Toolkit类)
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
- 清空消息历史,并插入一个占位的
HumanMessage("Continue"),保证在像 Anthropic 这类模型里保持对话兼容性。
ChromaDBManager
类说明
- 目标:保证整个项目里只存在一个 ChromaDB 客户端,避免多线程或多进程同时初始化带来的冲突。
- 关键点:用了 单例模式 + 线程锁 来保证全局唯一性。
class ChromaDBManager:
"""单例ChromaDB管理器,避免并发创建集合的冲突"""
_instance = None
_lock = threading.Lock()
_collections: Dict[str, any] = {}
_client = None
def __new__(cls):
if cls._instance is None:
...
cls._instance = super(ChromaDBManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
try:
...
self._initialized = True
def get_or_create_collection(self, name: str):
"""线程安全地获取或创建集合"""
with self._lock:
if name in self._collections:
logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")
return self._collections[name]
try:
# 尝试获取现有集合
collection = self._client.get_collection(name=name)
logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")
except Exception:
try:
# 创建新集合
...
# 缓存集合
self._collections[name] = collection
return collection
逐函数解析
类属性
_instance = None # 存储类的唯一实例
_lock = threading.Lock() # 线程锁,保证并发安全
_collections: Dict[str, any] = {} # 缓存已经创建/获取的集合,避免重复创建
_client = None # 底层 ChromaDB 客户端实例
__new__(cls)
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super(ChromaDBManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
-
作用:实现单例模式,确保只创建一个实例。
-
逻辑:
- 如果
_instance还没创建 → 加锁。 - 再次检查
_instance(双重检查锁 DCL,避免竞态)。 - 创建实例,并标记
_initialized=False,表示还没初始化。
- 如果
-
返回值:类的唯一实例。
__init__(self)
def __init__(self):
if not self._initialized:
try:
# 自动检测操作系统版本并使用最优配置
import platform
system = platform.system()
if system == "Windows":
# 使用改进的Windows 11检测
from .chromadb_win11_config import is_windows_11
if is_windows_11():
# Windows 11 或更新版本,使用优化配置
from .chromadb_win11_config import get_win11_chromadb_client
self._client = get_win11_chromadb_client()
logger.info(f"📚 [ChromaDB] Windows 11优化配置初始化完成 (构建号: {platform.version()})")
else:
# Windows 10 或更老版本,使用兼容配置
from .chromadb_win10_config import get_win10_chromadb_client
self._client = get_win10_chromadb_client()
logger.info(f"📚 [ChromaDB] Windows 10兼容配置初始化完成")
else:
# 非Windows系统,使用标准配置
settings = Settings(
allow_reset=True,
anonymized_telemetry=False,
is_persistent=False
)
self._client = chromadb.Client(settings)
logger.info(f"📚 [ChromaDB] {system}标准配置初始化完成")
self._initialized = True
except Exception as e:
logger.error(f"❌ [ChromaDB] 初始化失败: {e}")
# 使用最简单的配置作为备用
try:
settings = Settings(
allow_reset=True,
anonymized_telemetry=False, # 关键:禁用遥测
is_persistent=False
)
self._client = chromadb.Client(settings)
logger.info(f"📚 [ChromaDB] 使用备用配置初始化完成")
except Exception as backup_error:
# 最后的备用方案
self._client = chromadb.Client()
logger.warning(f"⚠️ [ChromaDB] 使用最简配置初始化: {backup_error}")
self._initialized = True
-
作用:在第一次创建时初始化
ChromaDB客户端。 -
逻辑:
- 检测操作系统(Windows / Linux / Mac)。
- Windows:进一步区分 Windows 11 优化配置 和 Windows 10 兼容配置。
- 其它系统:用标准配置
Settings(...)。 - 初始化失败 → 尝试 备用配置(禁用遥测、非持久化)。
- 如果还失败 → 用 最简配置
chromadb.Client()。
-
输出:无(初始化
_client)。
get_or_create_collection(self, name: str)
def get_or_create_collection(self, name: str):
"""线程安全地获取或创建集合(输出ChromaDB 集合对象)"""
with self._lock:
if name in self._collections:
logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")
return self._collections[name]
try:
# 尝试获取现有集合
collection = self._client.get_collection(name=name)
logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")
except Exception:
try:
# 创建新集合
collection = self._client.create_collection(name=name)
logger.info(f"📚 [ChromaDB] 创建新集合: {name}")
except Exception as e:
# 可能是并发创建,再次尝试获取
try:
collection = self._client.get_collection(name=name)
logger.info(f"📚 [ChromaDB] 并发创建后获取集合: {name}")
except Exception as final_error:
logger.error(f"❌ [ChromaDB] 集合操作失败: {name}, 错误: {final_error}")
raise final_error
# 缓存集合
self._collections[name] = collection
return collection
-
作用:获取或创建一个集合(类似数据库里的表)。
-
线程安全:加锁保证多个线程不会同时创建同一个集合。
-
逻辑:
- 如果缓存
_collections里已有 → 直接返回。 - 否则尝试
get_collection(name)。 - 如果失败(说明不存在) →
create_collection(name)。 - 如果创建也失败(可能是并发竞争) → 再尝试
get_collection(name)。 - 如果还是失败 → 抛出异常。
- 成功则缓存集合,并返回。
- 如果缓存
FinancialSituationMemory类
类说明
- 类可以视为一个 “财务情况记忆库”:
- 写入:新情况 + 建议 → 生成 embedding → 存入 ChromaDB
- 读取:新情况 → 生成 embedding → 相似度搜索 → 找到最相关的历史建议
简化后代码
class FinancialSituationMemory:
def __init__(self, name, config):
self.config = config
self.llm_provider = config.get("llm_provider", "openai").lower()
# 配置向量缓存的长度限制(向量缓存默认启用长度检查)
self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默认50K字符
self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量缓存默认启用
# 根据LLM提供商选择嵌入模型和客户端
self.fallback_available = False # 初始化降级选项标志
if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":
self.embedding = "text-embedding-v3"
self.client = None # DashScope不需要OpenAI客户端
# 设置DashScope API密钥
dashscope_key = os.getenv('DASHSCOPE_API_KEY')
if dashscope_key:
try:
# 尝试导入和初始化DashScope
import dashscope
from dashscope import TextEmbedding
dashscope.api_key = dashscope_key
logger.info(f"✅ DashScope API密钥已配置,启用记忆功能")
except ImportError as e:
# DashScope包未安装 ...
except Exception as e:
# 其他初始化错误 ...
else:
# 没有DashScope密钥,禁用记忆功能 ...
elif self.llm_provider == "deepseek":
...
# 使用单例ChromaDB管理器
self.chroma_manager = ChromaDBManager()
self.situation_collection = self.chroma_manager.get_or_create_collection(name)
def _smart_text_truncation(self, text, max_length=8192):
"""智能文本截断,保持语义完整性和缓存兼容性"""
if len(text) <= max_length:
return text, False # 返回原文本和是否截断的标志
# 尝试在句子边界截断
sentences = text.split('。')
...
# 尝试在段落边界截断
paragraphs = text.split('\n')
...
return truncated, True
def get_embedding(self, text):
"""Get embedding for a text using the configured provider"""
# 检查记忆功能是否被禁用
if self.client == "DISABLED":
# 内存功能已禁用,返回空向量 ...
if len(text) == 0: ... # 输入文本长度为0,返回空向量
# 检查是否启用长度限制
if self.enable_embedding_length_check and text_length > self.max_embedding_length:
# 文本过长跳过向量化并存储跳过信息 ...
return [0.0] * 1024
# 存储文本处理信息
self._last_text_info = { ... }
if (self.llm_provider == "dashscope" or
self.llm_provider == "alibaba" or
(self.llm_provider == "google" and self.client is None) or
(self.llm_provider == "deepseek" and self.client is None) or
(self.llm_provider == "openrouter" and self.client is None)):
# 使用阿里百炼的嵌入模型
try: ...
return embedding
else:
...
return [0.0] * 1024
def get_embedding_config_status(self):
"""获取向量缓存配置状态"""
return {
'enabled': self.enable_embedding_length_check,
'max_embedding_length': self.max_embedding_length,
'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",
'provider': self.llm_provider,
'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'
}
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
...
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings,
ids=ids,
)
Example
# Example usage
matcher = FinancialSituationMemory()
# Example data
example_data = [
(
"High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
),
(
"Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
),
(
"Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
),
(
"Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
),
]
# Add the example situations and recommendations
matcher.add_situations(example_data)
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1):
logger.info(f"\nMatch {i}:")
logger.info(f"Similarity Score: {rec.get('similarity', 0):.2f}")
logger.info(f"Matched Situation: {rec.get('situation', '')}")
logger.info(f"Recommendation: {rec.get('recommendation', '')}")
except Exception as e:
logger.error(f"Error during recommendation: {str(e)}")
逐函数解析
__init__(self, config)

-
作用
初始化类,设置向量数据库(ChromaDB),并根据config和环境变量自动选择合适的 Embedding 服务。 -
主要逻辑
- 保存
config。 - 初始化一些内部状态(provider、model、status、last_text_info 等)。
- 根据
config["provider"]来选择 embedding 服务:- DashScope / Alibaba → 用阿里云的 embedding API
- DeepSeek → 优先用 DashScope,其次 OpenAI,最后 DeepSeek 自己的 embedding
- Google → 优先 DashScope,如果配置里有
openai_api_key就启用 fallback - OpenRouter → DashScope embedding
- 本地 Ollama (localhost:11434) → 使用 nomic-embed-text
- 默认 → 尝试 OpenAI embedding,失败则禁用
- 保存
- 创建一个 ChromaDB 客户端,并建立/获取一个集合
financial_situations。
- 返回值 : 无(构造函数)。
def __init__(self, name, config):
self.config = config
self.llm_provider = config.get("llm_provider", "openai").lower()
# 配置向量缓存的长度限制(向量缓存默认启用长度检查)
self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000')) # 默认50K字符
self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true' # 向量缓存默认启用
# 根据LLM提供商选择嵌入模型和客户端
# 初始化降级选项标志
self.fallback_available = False
if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":
self.embedding = "text-embedding-v3"
self.client = None # DashScope不需要OpenAI客户端
# 设置DashScope API密钥
dashscope_key = os.getenv('DASHSCOPE_API_KEY')
if dashscope_key:
try:
# 尝试导入和初始化DashScope
import dashscope
from dashscope import TextEmbedding
dashscope.api_key = dashscope_key
logger.info(f"✅ DashScope API密钥已配置,启用记忆功能")
# 可选:测试API连接(简单验证)
# 这里不做实际调用,只验证导入和密钥设置
except ImportError as e:
# DashScope包未安装
logger.error(f"❌ DashScope包未安装: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ 记忆功能已禁用")
except Exception as e:
# 其他初始化错误
logger.error(f"❌ DashScope初始化失败: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ 记忆功能已禁用")
else:
# 没有DashScope密钥,禁用记忆功能
self.client = "DISABLED"
logger.warning(f"⚠️ 未找到DASHSCOPE_API_KEY,记忆功能已禁用")
logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
elif self.llm_provider == "deepseek":
# 检查是否强制使用OpenAI嵌入
force_openai = os.getenv('FORCE_OPENAI_EMBEDDING', 'false').lower() == 'true'
if not force_openai:
# 尝试使用阿里百炼嵌入
dashscope_key = os.getenv('DASHSCOPE_API_KEY')
if dashscope_key:
try:
# 测试阿里百炼是否可用
import dashscope
from dashscope import TextEmbedding
dashscope.api_key = dashscope_key
# 验证TextEmbedding可用性(不需要实际调用)
self.embedding = "text-embedding-v3"
self.client = None
logger.info(f"💡 DeepSeek使用阿里百炼嵌入服务")
except ImportError as e:
logger.error(f"⚠️ DashScope包未安装: {e}")
dashscope_key = None # 强制降级
except Exception as e:
logger.error(f"⚠️ 阿里百炼嵌入初始化失败: {e}")
dashscope_key = None # 强制降级
else:
dashscope_key = None # 跳过阿里百炼
if not dashscope_key or force_openai:
# 降级到OpenAI嵌入
self.embedding = "text-embedding-3-small"
openai_key = os.getenv('OPENAI_API_KEY')
if openai_key:
self.client = OpenAI(
api_key=openai_key,
base_url=config.get("backend_url", "https://api.openai.com/v1")
)
logger.warning(f"⚠️ DeepSeek回退到OpenAI嵌入服务")
else:
# 最后尝试DeepSeek自己的嵌入
deepseek_key = os.getenv('DEEPSEEK_API_KEY')
if deepseek_key:
try:
self.client = OpenAI(
api_key=deepseek_key,
base_url="https://api.deepseek.com"
)
logger.info(f"💡 DeepSeek使用自己的嵌入服务")
except Exception as e:
logger.error(f"❌ DeepSeek嵌入服务不可用: {e}")
# 禁用内存功能
self.client = "DISABLED"
logger.info(f"🚨 内存功能已禁用,系统将继续运行但不保存历史记忆")
else:
# 禁用内存功能而不是抛出异常
self.client = "DISABLED"
logger.info(f"🚨 未找到可用的嵌入服务,内存功能已禁用")
elif self.llm_provider == "google":
# Google AI使用阿里百炼嵌入(如果可用),否则禁用记忆功能
dashscope_key = os.getenv('DASHSCOPE_API_KEY')
openai_key = os.getenv('OPENAI_API_KEY')
if dashscope_key:
try:
# 尝试初始化DashScope
import dashscope
from dashscope import TextEmbedding
self.embedding = "text-embedding-v3"
self.client = None
dashscope.api_key = dashscope_key
# 检查是否有OpenAI密钥作为降级选项
if openai_key:
logger.info(f"💡 Google AI使用阿里百炼嵌入服务(OpenAI作为降级选项)")
self.fallback_available = True
self.fallback_client = OpenAI(api_key=openai_key, base_url=config["backend_url"])
self.fallback_embedding = "text-embedding-3-small"
else:
logger.info(f"💡 Google AI使用阿里百炼嵌入服务(无降级选项)")
self.fallback_available = False
except ImportError as e:
logger.error(f"❌ DashScope包未安装: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ Google AI记忆功能已禁用")
except Exception as e:
logger.error(f"❌ DashScope初始化失败: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ Google AI记忆功能已禁用")
else:
# 没有DashScope密钥,禁用记忆功能
self.client = "DISABLED"
self.fallback_available = False
logger.warning(f"⚠️ Google AI未找到DASHSCOPE_API_KEY,记忆功能已禁用")
logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
elif self.llm_provider == "openrouter":
# OpenRouter支持:优先使用阿里百炼嵌入,否则禁用记忆功能
dashscope_key = os.getenv('DASHSCOPE_API_KEY')
if dashscope_key:
try:
# 尝试使用阿里百炼嵌入
import dashscope
from dashscope import TextEmbedding
self.embedding = "text-embedding-v3"
self.client = None
dashscope.api_key = dashscope_key
logger.info(f"💡 OpenRouter使用阿里百炼嵌入服务")
except ImportError as e:
logger.error(f"❌ DashScope包未安装: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ OpenRouter记忆功能已禁用")
except Exception as e:
logger.error(f"❌ DashScope初始化失败: {e}")
self.client = "DISABLED"
logger.warning(f"⚠️ OpenRouter记忆功能已禁用")
else:
# 没有DashScope密钥,禁用记忆功能
self.client = "DISABLED"
logger.warning(f"⚠️ OpenRouter未找到DASHSCOPE_API_KEY,记忆功能已禁用")
logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
elif config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
self.client = OpenAI(base_url=config["backend_url"])
else:
self.embedding = "text-embedding-3-small"
openai_key = os.getenv('OPENAI_API_KEY')
if openai_key:
self.client = OpenAI(
api_key=openai_key,
base_url=config["backend_url"]
)
else:
self.client = "DISABLED"
logger.warning(f"⚠️ 未找到OPENAI_API_KEY,记忆功能已禁用")
# 使用单例ChromaDB管理器
self.chroma_manager = ChromaDBManager()
self.situation_collection = self.chroma_manager.get_or_create_collection(name)
_smart_text_truncation(self, text, max_length)
def _smart_text_truncation(self, text, max_length=8192):
"""智能文本截断,保持语义完整性和缓存兼容性"""
if len(text) <= max_length:
return text, False # 返回原文本和是否截断的标志
# 尝试在句子边界截断
sentences = text.split('。')
if len(sentences) > 1:
truncated = ""
for sentence in sentences:
if len(truncated + sentence + '。') <= max_length - 50: # 留50字符余量
truncated += sentence + '。'
else:
break
if len(truncated) > max_length // 2: # 至少保留一半内容
logger.info(f"📝 智能截断:在句子边界截断,保留{len(truncated)}/{len(text)}字符")
return truncated, True
# 尝试在段落边界截断
paragraphs = text.split('\n')
if len(paragraphs) > 1:
truncated = ""
for paragraph in paragraphs:
if len(truncated + paragraph + '\n') <= max_length - 50:
truncated += paragraph + '\n'
else:
break
if len(truncated) > max_length // 2:
logger.info(f"📝 智能截断:在段落边界截断,保留{len(truncated)}/{len(text)}字符")
return truncated, True
# 最后选择:保留前半部分和后半部分的关键信息
front_part = text[:max_length//2]
back_part = text[-(max_length//2-100):] # 留100字符给连接符
truncated = front_part + "\n...[内容截断]...\n" + back_part
logger.warning(f"⚠️ 强制截断:保留首尾关键信息,{len(text)}字符截断为{len(truncated)}字符")
return truncated, True
-
作用
保证输入文本不会超过max_length,但尽量保持语义完整(比如按句子、段落截断)。 -
主要逻辑
- 如果文本长度 ≤
max_length→ 原样返回。 - 如果太长:
- 尝试在句号(。.!?)之后截断,保留前面一段完整句子。
- 否则尝试按段落
\n截断。 - 如果都不行,就取前
max_length//2和后max_length//2拼接,中间插入...。
- 把截断后的文本保存到
self.last_text_info,记录是否截断、采用哪种策略、原始/最终长度等。
- 如果文本长度 ≤
-
输入
text(str):原始文本max_length(int):允许的最大长度
-
输出
- 截断后的文本 (str)
get_embedding(self, text, max_length=1000)
def get_embedding(self, text):
"""Get embedding for a text using the configured provider"""
# 检查记忆功能是否被禁用
if self.client == "DISABLED":
# 内存功能已禁用,返回空向量
logger.debug(f"⚠️ 记忆功能已禁用,返回空向量")
return [0.0] * 1024 # 返回1024维的零向量
# 验证输入文本
if not text or not isinstance(text, str):
logger.warning(f"⚠️ 输入文本为空或无效,返回空向量")
return [0.0] * 1024
text_length = len(text)
if text_length == 0:
logger.warning(f"⚠️ 输入文本长度为0,返回空向量")
return [0.0] * 1024
# 检查是否启用长度限制
if self.enable_embedding_length_check and text_length > self.max_embedding_length:
logger.warning(f"⚠️ 文本过长({text_length:,}字符 > {self.max_embedding_length:,}字符),跳过向量化")
# 存储跳过信息
self._last_text_info = {
'original_length': text_length,
'processed_length': 0,
'was_truncated': False,
'was_skipped': True,
'provider': self.llm_provider,
'strategy': 'length_limit_skip',
'max_length': self.max_embedding_length
}
return [0.0] * 1024
# 记录文本信息(不进行任何截断)
if text_length > 8192:
logger.info(f"📝 处理长文本: {text_length}字符,提供商: {self.llm_provider}")
# 存储文本处理信息
self._last_text_info = {
'original_length': text_length,
'processed_length': text_length, # 不截断,保持原长度
'was_truncated': False, # 永不截断
'was_skipped': False,
'provider': self.llm_provider,
'strategy': 'no_truncation_with_fallback' # 标记策略
}
if (self.llm_provider == "dashscope" or
self.llm_provider == "alibaba" or
(self.llm_provider == "google" and self.client is None) or
(self.llm_provider == "deepseek" and self.client is None) or
(self.llm_provider == "openrouter" and self.client is None)):
# 使用阿里百炼的嵌入模型
try:
# 导入DashScope模块
import dashscope
from dashscope import TextEmbedding
# 检查DashScope API密钥是否可用
if not hasattr(dashscope, 'api_key') or not dashscope.api_key:
logger.warning(f"⚠️ DashScope API密钥未设置,记忆功能降级")
return [0.0] * 1024 # 返回空向量
# 尝试调用DashScope API
response = TextEmbedding.call(
model=self.embedding,
input=text
)
# 检查响应状态
if response.status_code == 200:
# 成功获取embedding
embedding = response.output['embeddings'][0]['embedding']
logger.debug(f"✅ DashScope embedding成功,维度: {len(embedding)}")
return embedding
else:
# API返回错误状态码
error_msg = f"{response.code} - {response.message}"
# 检查是否为长度限制错误
if any(keyword in error_msg.lower() for keyword in ['length', 'token', 'limit', 'exceed']):
logger.warning(f"⚠️ DashScope长度限制: {error_msg}")
# 检查是否有降级选项
if hasattr(self, 'fallback_available') and self.fallback_available:
logger.info(f"💡 尝试使用OpenAI降级处理长文本")
try:
response = self.fallback_client.embeddings.create(
model=self.fallback_embedding,
input=text
)
embedding = response.data[0].embedding
logger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")
return embedding
except Exception as fallback_error:
logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")
logger.info(f"💡 所有降级选项失败,记忆功能降级")
return [0.0] * 1024
else:
logger.info(f"💡 无可用降级选项,记忆功能降级")
return [0.0] * 1024
else:
logger.error(f"❌ DashScope API错误: {error_msg}")
return [0.0] * 1024 # 返回空向量而不是抛出异常
except Exception as e:
error_str = str(e).lower()
# 检查是否为长度限制错误
if any(keyword in error_str for keyword in ['length', 'token', 'limit', 'exceed', 'too long']):
logger.warning(f"⚠️ DashScope长度限制异常: {str(e)}")
# 检查是否有降级选项
if hasattr(self, 'fallback_available') and self.fallback_available:
logger.info(f"💡 尝试使用OpenAI降级处理长文本")
try:
response = self.fallback_client.embeddings.create(
model=self.fallback_embedding,
input=text
)
embedding = response.data[0].embedding
logger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")
return embedding
except Exception as fallback_error:
logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")
logger.info(f"💡 所有降级选项失败,记忆功能降级")
return [0.0] * 1024
else:
logger.info(f"💡 无可用降级选项,记忆功能降级")
return [0.0] * 1024
elif 'import' in error_str:
logger.error(f"❌ DashScope包未安装: {str(e)}")
elif 'connection' in error_str:
logger.error(f"❌ DashScope网络连接错误: {str(e)}")
elif 'timeout' in error_str:
logger.error(f"❌ DashScope请求超时: {str(e)}")
else:
logger.error(f"❌ DashScope embedding异常: {str(e)}")
logger.warning(f"⚠️ 记忆功能降级,返回空向量")
return [0.0] * 1024
else:
# 使用OpenAI兼容的嵌入模型
if self.client is None:
logger.warning(f"⚠️ 嵌入客户端未初始化,返回空向量")
return [0.0] * 1024 # 返回空向量
elif self.client == "DISABLED":
# 内存功能已禁用,返回空向量
logger.debug(f"⚠️ 内存功能已禁用,返回空向量")
return [0.0] * 1024 # 返回1024维的零向量
# 尝试调用OpenAI兼容的embedding API
try:
response = self.client.embeddings.create(
model=self.embedding,
input=text
)
embedding = response.data[0].embedding
logger.debug(f"✅ {self.llm_provider} embedding成功,维度: {len(embedding)}")
return embedding
except Exception as e:
error_str = str(e).lower()
# 检查是否为长度限制错误
length_error_keywords = [
'token', 'length', 'too long', 'exceed', 'maximum', 'limit',
'context', 'input too large', 'request too large'
]
is_length_error = any(keyword in error_str for keyword in length_error_keywords)
if is_length_error:
# 长度限制错误:直接降级,不截断重试
logger.warning(f"⚠️ {self.llm_provider}长度限制: {str(e)}")
logger.info(f"💡 为保证分析准确性,不截断文本,记忆功能降级")
else:
# 其他类型的错误
if 'attributeerror' in error_str:
logger.error(f"❌ {self.llm_provider} API调用错误: {str(e)}")
elif 'connectionerror' in error_str or 'connection' in error_str:
logger.error(f"❌ {self.llm_provider}网络连接错误: {str(e)}")
elif 'timeout' in error_str:
logger.error(f"❌ {self.llm_provider}请求超时: {str(e)}")
elif 'keyerror' in error_str:
logger.error(f"❌ {self.llm_provider}响应格式错误: {str(e)}")
else:
logger.error(f"❌ {self.llm_provider} embedding异常: {str(e)}")
logger.warning(f"⚠️ 记忆功能降级,返回空向量")
return [0.0] * 1024
-
作用 : 把文本转成向量 embedding。
-
主要逻辑
- 调用
_smart_text_truncation保证长度安全。 - 根据 provider 选择对应的 API:
- DashScope → 调用
dashscope.TextEmbedding.call,模型是text-embedding-v2 - OpenAI → 调用
client.embeddings.create(model="text-embedding-3-small") - DeepSeek → 可能走 DeepSeek 自己的 embedding API
- Ollama → 调用
client.embeddings.create(model="nomic-embed-text") - 其它情况 → 返回
[0.0] * 1024(代表禁用)
- DashScope → 调用
- 如果调用失败,捕获异常并返回零向量。
- 调用
-
输入
text(str):输入文本max_length(int):最大允许长度(默认 1000)
-
输出
- 向量 embedding (list[float])
get_embedding_config_status(self)
def get_embedding_config_status(self):
"""获取向量缓存配置状态"""
return {
'enabled': self.enable_embedding_length_check,
'max_embedding_length': self.max_embedding_length,
'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",
'provider': self.llm_provider,
'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'
}
- 作用 返回当前 embedding 的配置信息,主要用于调试/检查状态。
get_last_text_info(self)
def get_last_text_info(self):
"""获取最后处理的文本信息"""
return getattr(self, '_last_text_info', None)
- 作用:返回最近一次
_smart_text_truncation的信息(调试用)。
add_situations(self, situations)
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
situations = []
advice = []
ids = []
embeddings = []
offset = self.situation_collection.count()
for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation)
advice.append(recommendation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings,
ids=ids,
)
- 作用:把一组新的“财务情况 + 建议”存入数据库。
主要逻辑
-
遍历
situations,每个元素是(situation, recommendation)。 -
对
situation文本生成 embedding。 -
构造数据项:
{ "id": str(uuid.uuid4()), "embedding": 向量, "metadata": {"recommendation": 建议} } -
批量写入
self.collection(ChromaDB)。
get_cache_info(self)
def get_cache_info(self):
"""获取缓存相关信息,用于调试和监控"""
info = {
'collection_count': self.situation_collection.count(),
'client_status': 'enabled' if self.client != "DISABLED" else 'disabled',
'embedding_model': self.embedding,
'provider': self.llm_provider
}
# 添加最后一次文本处理信息
if hasattr(self, '_last_text_info'):
info['last_text_processing'] = self._last_text_info
return info
- 作用
返回当前 ChromaDB 集合的一些元信息(比如名称、模型、provider 等),帮助确认系统运行状态。
get_memories(self, situation, n_results=5)
def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using embeddings with smart truncation handling"""
# 获取当前情况的embedding
query_embedding = self.get_embedding(current_situation)
# 检查是否为空向量(记忆功能被禁用或出错)
if all(x == 0.0 for x in query_embedding):
logger.debug(f"⚠️ 查询embedding为空向量,返回空结果")
return []
# 检查是否有足够的数据进行查询
collection_count = self.situation_collection.count()
if collection_count == 0:
logger.debug(f"📭 记忆库为空,返回空结果")
return []
# 调整查询数量,不能超过集合中的文档数量
actual_n_matches = min(n_matches, collection_count)
try:
# 执行相似度查询
results = self.situation_collection.query(
query_embeddings=[query_embedding],
n_results=actual_n_matches
)
# 处理查询结果
memories = []
if results and 'documents' in results and results['documents']:
documents = results['documents'][0]
metadatas = results.get('metadatas', [[]])[0]
distances = results.get('distances', [[]])[0]
for i, doc in enumerate(documents):
metadata = metadatas[i] if i < len(metadatas) else {}
distance = distances[i] if i < len(distances) else 1.0
memory_item = {
'situation': doc,
'recommendation': metadata.get('recommendation', ''),
'similarity': 1.0 - distance, # 转换为相似度分数
'distance': distance
}
memories.append(memory_item)
# 记录查询信息
if hasattr(self, '_last_text_info') and self._last_text_info.get('was_truncated'):
logger.info(f"🔍 截断文本查询完成,找到{len(memories)}个相关记忆")
logger.debug(f"📊 原文长度: {self._last_text_info['original_length']}, "
f"处理后长度: {self._last_text_info['processed_length']}")
else:
logger.debug(f"🔍 记忆查询完成,找到{len(memories)}个相关记忆")
return memories
except Exception as e:
logger.error(f"❌ 记忆查询失败: {str(e)}")
return []
-
作用:检索最相似的历史情况,返回对应的建议。
-
主要逻辑
- 对输入
situation生成 embedding。 - 调用
self.collection.query,检索最相似的n_results条记录。 - 解析返回结果:提取文本、embedding 相似度/距离、推荐建议。
- 返回一个结果列表。
- 对输入
2183

被折叠的 条评论
为什么被折叠?



