【经济学】量化模型TradingAgents 工具集成层与数据(财报+ 基本信息&指标+基本面分析)+ChromaDB 客户端+财务情况记忆库

Toolkit 作用

  • Toolkit(tradingagents/agents/utils/agent_utils.py) TradingAgents 框架里的工具集,封装了各种外部数据源的调用接口(新闻、财务、行情、情绪、指标等),并通过 @tool 装饰器暴露给 LLM 使用。

  • TradingAgents Toolkit 功能总览表

模块类别主要函数/工具功能描述数据来源状态
市场行情 (Market)get_market_data_unified获取股票/指数的市场行情数据(价格、成交量、技术指标)Yahoo Finance, Tushare, 东方财富等
get_price_history获取历史 K 线数据同上
get_technical_indicators生成技术指标(MA, RSI, MACD 等)本地计算
新闻 (News)get_news_articles抓取金融新闻Google News, 东方财富新闻等
summarize_news对新闻做摘要LLM 处理
社交媒体 (Social)get_social_sentiment获取社交媒体情绪(Twitter, 微博)API / 爬虫数据源可能受限
analyze_sentiment使用 LLM 对评论、帖子做情绪分析LLM
基本面 (Fundamentals)get_stock_fundamentals_unified统一接口:获取美股/A股/港股的财务数据与估值Yahoo Finance, SimFin, Tushare, AKShare, 东方财富
get_fundamentals_openai使用 OpenAI Agent 调用财务数据OpenAI Agent废弃
get_china_fundamentals获取中国 A 股财务数据(旧接口)Tushare, AKShare废弃
风险分析 (Risk)get_risk_metrics计算风险指标(波动率、夏普比率、回撤)本地计算
stress_test压力测试(不同市场情景下的资产表现)本地模拟
portfolio_risk_analysis投资组合风险分析本地计算 + 历史行情

Toolkit 逐函数解析

  • 大部分函数都是通过tradingagents.dataflows.interface获取数据。

1. 获取默认配置

from tradingagents.default_config import DEFAULT_CONFIG
_config = DEFAULT_CONFIG.copy()
  • 获取默认配置,默认配置如下:
import os

DEFAULT_CONFIG = {
    "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
    "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
    "data_dir": os.path.join(os.path.expanduser("~"), "Documents", "TradingAgents", "data"),
    "data_cache_dir": os.path.join(
        os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
        "dataflows/data_cache",
    ),
    # LLM settings
    "llm_provider": "openai",
    "deep_think_llm": "o4-mini",
    "quick_think_llm": "gpt-4o-mini",
    "backend_url": "https://api.openai.com/v1",
    # Debate and discussion settings
    "max_debate_rounds": 1,
    "max_risk_discuss_rounds": 1,
    "max_recur_limit": 100,
    # Tool settings
    "online_tools": True,

    # Note: Database and cache configuration is now managed by .env file and config.database_manager
    # No database/cache settings in default config to avoid configuration conflicts
}

2. update_config

@classmethod
def update_config(cls, config):
    """Update the class-level configuration."""
    cls._config.update(config)
  • 更新 Toolkit 的全局配置(类级别),比如数据源 API key、默认参数等。

3. config

@property
def config(self):
    """Access the configuration."""
    return self._config
  • 返回当前配置。

4. __init__

def __init__(self, config=None):
    if config:
        self.update_config(config)
  • 构造函数,可传入配置并更新默认配置。

5. get_reddit_news

@staticmethod
@tool
def get_reddit_news(
    curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:
    """
    Retrieve global news from Reddit within a specified time frame.
    Args:
        curr_date (str): Date you want to get news for in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
    """
    
    global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)

    return global_news_result
  • 从 Reddit 获取某天起过去 7 天内的 全球新闻(最多 5 条),主要用于宏观舆情分析。

6. get_finnhub_news

@staticmethod
@tool
def get_finnhub_news(
    ticker: Annotated[str, "Search query of a company, e.g. 'AAPL, TSM, etc."],
    start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
    end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
    """
    Retrieve the latest news about a given stock from Finnhub within a date range
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing news about the company within the date range from start_date to end_date
    """

    end_date_str = end_date

    end_date = datetime.strptime(end_date, "%Y-%m-%d")
    start_date = datetime.strptime(start_date, "%Y-%m-%d")
    look_back_days = (end_date - start_date).days

    finnhub_news_result = interface.get_finnhub_news(
        ticker, end_date_str, look_back_days
    )

    return finnhub_news_result
  • 调用 Finnhub API 获取指定股票在 start_date ~ end_date 的新闻。

7. get_reddit_stock_info

@staticmethod
@tool
def get_reddit_stock_info(
    ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],
    curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
    """
    Retrieve the latest news about a given stock from Reddit, given the current date.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): current date in yyyy-mm-dd format to get news for
    Returns:
        str: A formatted dataframe containing the latest news about the company on the given date
    """

    stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)

    return stock_news_results
  • 从 Reddit 获取某只股票的近期新闻和讨论。

8. get_chinese_social_sentiment

@staticmethod
@tool
def get_chinese_social_sentiment(
    ticker: Annotated[str, "股票代码,例如 '600519.SS' 或 '000001.SZ'"],
    curr_date: Annotated[str, "要获取的日期,yyyy-mm-dd 格式"],
) -> str:
    """
    获取某只股票在中国社交媒体上的情绪分析。
    Args:
        ticker (str): 股票代码 (A股格式)
        curr_date (str): 要获取的日期
    Returns:
        str: 包含中国社交媒体情绪分析的格式化 dataframe
    """
    try:
        sentiment_result = interface.get_chinese_sentiment(ticker, curr_date, 7)
    except Exception as e:
        logger.warning(f"中国舆情数据获取失败,回退到 Reddit: {e}")
        sentiment_result = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
    return sentiment_result
  • 获取 A 股股票在 中国本土社交媒体 上的舆情/情绪数据。
    • 如果中国数据源失败 → 自动回退到 Reddit
    • 常用于国内公司投资情绪分析。

9. get_finnhub_company_insider_sentiment

@staticmethod
@tool
def get_finnhub_company_insider_sentiment(
    ticker: Annotated[str, "股票代码,例如 'AAPL', 'TSM'"],
) -> str:
    """
    获取公司内部人买卖行为的情绪数据。
    Args:
        ticker (str): 公司代码
    Returns:
        str: 格式化的内部人情绪数据
    """
    insider_sentiment = interface.get_finnhub_insider_sentiment(ticker)
    return insider_sentiment
  • 调用 Finnhub API 获取 内部人买卖股票的情绪指标(比如 CEO、CFO 买入/卖出, 属于“聪明钱”指标)。

10. get_YFin_data

@staticmethod
@tool
def get_YFin_data(
    ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],
    period: Annotated[str, "数据周期,例如 '1mo', '6mo', '1y'"],
) -> str:
    """
    获取 Yahoo Finance 历史行情数据。
    Args:
        ticker (str): 股票代码
        period (str): 时间范围
    Returns:
        str: 格式化的历史行情 dataframe
    """
    df = interface.get_yahoo_finance_history(ticker, period)
    return df.to_string()
  • Yahoo Finance 获取某只股票的历史行情(K 线数据)。

11. get_YFin_data_online

@staticmethod
@tool
def get_YFin_data_online(
    ticker: Annotated[str, "公司股票代码"],
    start_date: Annotated[str, "开始日期 yyyy-mm-dd"],
    end_date: Annotated[str, "结束日期 yyyy-mm-dd"],
) -> str:
    """
    获取 Yahoo Finance 区间行情。
    """
    df = interface.get_yahoo_finance_range(ticker, start_date, end_date)
    return df.to_string()
  • 类似 get_YFin_data,但支持 指定开始和结束日期

12. get_stockstats_indicators_report

@staticmethod
@tool
def get_stockstats_indicators_report(
    ticker: Annotated[str, "公司股票代码"],
    period: Annotated[str, "分析周期,例如 '6mo'"],
) -> str:
    """
    获取技术指标分析报告(基于 stockstats)。
    """
    df = interface.get_stockstats_indicators(ticker, period)
    return df.to_string()
  • 基于 stockstats 库计算技术指标(均线、RSI、MACD 等)。

13. get_stockstats_indicators_report_online

@staticmethod
@tool
def get_stockstats_indicators_report_online(
    ticker: Annotated[str, "公司股票代码"],
    start_date: Annotated[str, "开始日期"],
    end_date: Annotated[str, "结束日期"],
) -> str:
    """
    获取指定区间的技术指标分析。
    """
    df = interface.get_stockstats_indicators_range(ticker, start_date, end_date)
    return df.to_string()
  • get_stockstats_indicators_report 类似,但支持 自定义区间

14. get_simfin_balance_sheet 资产负债表

@staticmethod
@tool
def get_simfin_balance_sheet(
    ticker: Annotated[str, "公司股票代码,例如 'AAPL', 'TSM'"],
    report_type: Annotated[str, "报告类型,例如 'annual', 'quarterly'"],
) -> str:
    """
    获取公司资产负债表 (Balance Sheet)。
    """
    df = interface.get_simfin_balance_sheet(ticker, report_type)
    return df.to_string()
  • SimFin API 获取 资产负债表
    • 可选年度 (annual) 或季度 (quarterly)。
    • 返回格式化的 dataframe。

15. get_simfin_income_statement利润表

@staticmethod
@tool
def get_simfin_income_statement(
    ticker: Annotated[str, "公司股票代码"],
    report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
    """
    获取公司利润表 (Income Statement)。
    """
    df = interface.get_simfin_income_statement(ticker, report_type)
    return df.to_string()
  • 获取 利润表(营业收入、净利润、毛利率等)。
    • 主要用于盈利能力分析。

16. get_simfin_cashflow_statement 现金流量表

@staticmethod
@tool
def get_simfin_cashflow_statement(
    ticker: Annotated[str, "公司股票代码"],
    report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
    """
    获取公司现金流量表 (Cashflow Statement)。
    """
    df = interface.get_simfin_cashflow_statement(ticker, report_type)
    return df.to_string()
  • 获取 现金流量表(经营活动、投资活动、融资活动现金流)。
    • 用于衡量公司“造血能力”和资金链稳定性。

17. get_simfin_ratios 财务指标比率

@staticmethod
@tool
def get_simfin_ratios(
    ticker: Annotated[str, "公司股票代码"],
    report_type: Annotated[str, "annual 或 quarterly"],
) -> str:
    """
    获取公司财务比率 (Ratios),例如 PE、ROE、负债率。
    """
    df = interface.get_simfin_ratios(ticker, report_type)
    return df.to_string()
  • 获取 财务指标比率(PE, PB, ROE, 负债率, 流动比率)。
    • 适合做跨公司对比。

18. get_simfin_company_info 公司基本信息

@staticmethod
@tool
def get_simfin_company_info(
    ticker: Annotated[str, "公司股票代码"]
) -> str:
    """
    获取公司基本信息(行业、地区、规模等)。
    """
    df = interface.get_simfin_company_info(ticker)
    return df.to_string()
  • 获取公司的 基本信息(行业分类、上市地、公司规模)。
    • 在做行业对比或聚类分析时很有用。

19. get_simfin_shareprices 历史股价数据

@staticmethod
@tool
def get_simfin_shareprices(
    ticker: Annotated[str, "公司股票代码"],
    start_date: Annotated[str, "开始日期"],
    end_date: Annotated[str, "结束日期"],
) -> str:
    """
    获取公司历史股价 (Share Prices)。
    """
    df = interface.get_simfin_shareprices(ticker, start_date, end_date)
    return df.to_string()
  • 获取 SimFin 的历史股价数据。
    • 类似 Yahoo Finance,但数据源不同。

20. get_fundamentals_openai 财务+估值

@staticmethod
@tool
def get_fundamentals_openai(
    ticker: Annotated[str, "公司代码,例如 AAPL, TSLA"],
    report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:
    """
    使用 OpenAI Agent 获取公司财务和估值信息。
    (已废弃,推荐使用 get_stock_fundamentals_unified)
    """
    logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")
    return interface.get_fundamentals_openai(ticker, report_type)
  • 原始版本,用 OpenAI Agent 来获取财务+估值。
    • 已经 废弃,现在统一整合进 get_stock_fundamentals_unified

21. get_china_fundamentals A 股财务数据

@staticmethod
@tool
def get_china_fundamentals(
    ticker: Annotated[str, "中国股票代码,例如 600519"],
    report_type: Annotated[str, "年度/季度"] = "annual",
) -> str:
    """
    获取中国 A 股财务数据(通过 Tushare 或 AKShare)。
    (已废弃,推荐使用 get_stock_fundamentals_unified)
    """
    logger.warning("⚠️ [DEPRECATED] 推荐使用 get_stock_fundamentals_unified() 代替")
    return interface.get_china_fundamentals(ticker, report_type)
  • 早期用于获取 A 股财务数据
    • 数据源:Tushare / AKShare
    • 废弃,功能已被统一接口替代。

22. get_stock_fundamentals_unified 整合财务+估值

@staticmethod
@tool
def get_stock_fundamentals_unified(
    ticker: Annotated[str, "股票代码,例如 AAPL, 600519, 00700.HK"],
    market: Annotated[str, "市场类型:us / cn / hk"] = "us",
    report_type: Annotated[str, "annual 或 quarterly"] = "annual",
) -> str:
    """
    统一接口:自动识别市场并获取财务数据与估值。
    """
    return interface.get_stock_fundamentals_unified(ticker, market, report_type)
  • 核心统一入口,整合了所有财务+估值工具:
    • 美股 → Yahoo Finance + SimFin
    • A 股 → Tushare + AKShare
    • 港股 → Yahoo Finance + 东方财富
  • 自动根据 market 参数(us/cn/hk)选择合适数据源。
  • 返回内容包括:
    • 财报(资产负债表、利润表、现金流)
    • 估值指标(PE, PB, PEG, ROE 等)
    • 行业对比

create_msg_delete函数(非Toolkit类)

def create_msg_delete():
    def delete_messages(state):
        """Clear messages and add placeholder for Anthropic compatibility"""
        messages = state["messages"]
        
        # Remove all messages
        removal_operations = [RemoveMessage(id=m.id) for m in messages]
        
        # Add a minimal placeholder message
        placeholder = HumanMessage(content="Continue")
        
        return {"messages": removal_operations + [placeholder]}
    
    return delete_messages
  • 清空消息历史,并插入一个占位的 HumanMessage("Continue"),保证在像 Anthropic 这类模型里保持对话兼容性。

ChromaDBManager

类说明

  • 目标:保证整个项目里只存在一个 ChromaDB 客户端,避免多线程或多进程同时初始化带来的冲突。
  • 关键点:用了 单例模式 + 线程锁 来保证全局唯一性。

class ChromaDBManager:
    """单例ChromaDB管理器,避免并发创建集合的冲突"""
    _instance = None
    _lock = threading.Lock()
    _collections: Dict[str, any] = {}
    _client = None

    def __new__(cls):
        if cls._instance is None:
            ...
                    cls._instance = super(ChromaDBManager, cls).__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

   def __init__(self):
    if not self._initialized:
        try:
            ...
            self._initialized = True
    def get_or_create_collection(self, name: str):
        """线程安全地获取或创建集合"""
        with self._lock:
            if name in self._collections:
                logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")
                return self._collections[name]
            try:
                # 尝试获取现有集合
                collection = self._client.get_collection(name=name)
                logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")
            except Exception:
                try:
                    # 创建新集合
                    ...
            # 缓存集合
            self._collections[name] = collection
            return collection

逐函数解析

类属性

_instance = None # 存储类的唯一实例
_lock = threading.Lock() # 线程锁,保证并发安全
_collections: Dict[str, any] = {} # 缓存已经创建/获取的集合,避免重复创建
_client = None # 底层 ChromaDB 客户端实例

__new__(cls)

def __new__(cls):
    if cls._instance is None:
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ChromaDBManager, cls).__new__(cls)
                cls._instance._initialized = False
    return cls._instance
  • 作用:实现单例模式,确保只创建一个实例。

  • 逻辑

    1. 如果 _instance 还没创建 → 加锁。
    2. 再次检查 _instance(双重检查锁 DCL,避免竞态)。
    3. 创建实例,并标记 _initialized=False,表示还没初始化。
  • 返回值:类的唯一实例。

__init__(self)

    def __init__(self):
        if not self._initialized:
            try:
                # 自动检测操作系统版本并使用最优配置
                import platform
                system = platform.system()
                
                if system == "Windows":
                    # 使用改进的Windows 11检测
                    from .chromadb_win11_config import is_windows_11
                    if is_windows_11():
                        # Windows 11 或更新版本,使用优化配置
                        from .chromadb_win11_config import get_win11_chromadb_client
                        self._client = get_win11_chromadb_client()
                        logger.info(f"📚 [ChromaDB] Windows 11优化配置初始化完成 (构建号: {platform.version()})")
                    else:
                        # Windows 10 或更老版本,使用兼容配置
                        from .chromadb_win10_config import get_win10_chromadb_client
                        self._client = get_win10_chromadb_client()
                        logger.info(f"📚 [ChromaDB] Windows 10兼容配置初始化完成")
                else:
                    # 非Windows系统,使用标准配置
                    settings = Settings(
                        allow_reset=True,
                        anonymized_telemetry=False,
                        is_persistent=False
                    )
                    self._client = chromadb.Client(settings)
                    logger.info(f"📚 [ChromaDB] {system}标准配置初始化完成")
                
                self._initialized = True
            except Exception as e:
                logger.error(f"❌ [ChromaDB] 初始化失败: {e}")
                # 使用最简单的配置作为备用
                try:
                    settings = Settings(
                        allow_reset=True,
                        anonymized_telemetry=False,  # 关键:禁用遥测
                        is_persistent=False
                    )
                    self._client = chromadb.Client(settings)
                    logger.info(f"📚 [ChromaDB] 使用备用配置初始化完成")
                except Exception as backup_error:
                    # 最后的备用方案
                    self._client = chromadb.Client()
                    logger.warning(f"⚠️ [ChromaDB] 使用最简配置初始化: {backup_error}")
                self._initialized = True
  • 作用:在第一次创建时初始化 ChromaDB 客户端。

  • 逻辑

    1. 检测操作系统(Windows / Linux / Mac)。
    2. Windows:进一步区分 Windows 11 优化配置Windows 10 兼容配置
    3. 其它系统:用标准配置 Settings(...)
    4. 初始化失败 → 尝试 备用配置(禁用遥测、非持久化)。
    5. 如果还失败 → 用 最简配置 chromadb.Client()
  • 输出:无(初始化 _client)。


get_or_create_collection(self, name: str)

    def get_or_create_collection(self, name: str):
        """线程安全地获取或创建集合(输出ChromaDB 集合对象)"""
        with self._lock:
            if name in self._collections:
                logger.info(f"📚 [ChromaDB] 使用缓存集合: {name}")
                return self._collections[name]

            try:
                # 尝试获取现有集合
                collection = self._client.get_collection(name=name)
                logger.info(f"📚 [ChromaDB] 获取现有集合: {name}")
            except Exception:
                try:
                    # 创建新集合
                    collection = self._client.create_collection(name=name)
                    logger.info(f"📚 [ChromaDB] 创建新集合: {name}")
                except Exception as e:
                    # 可能是并发创建,再次尝试获取
                    try:
                        collection = self._client.get_collection(name=name)
                        logger.info(f"📚 [ChromaDB] 并发创建后获取集合: {name}")
                    except Exception as final_error:
                        logger.error(f"❌ [ChromaDB] 集合操作失败: {name}, 错误: {final_error}")
                        raise final_error

            # 缓存集合
            self._collections[name] = collection
            return collection
  • 作用:获取或创建一个集合(类似数据库里的表)。

  • 线程安全:加锁保证多个线程不会同时创建同一个集合。

  • 逻辑

    1. 如果缓存 _collections 里已有 → 直接返回。
    2. 否则尝试 get_collection(name)
    3. 如果失败(说明不存在) → create_collection(name)
    4. 如果创建也失败(可能是并发竞争) → 再尝试 get_collection(name)
    5. 如果还是失败 → 抛出异常。
    6. 成功则缓存集合,并返回。

FinancialSituationMemory类

类说明

  • 类可以视为一个 “财务情况记忆库”
    • 写入:新情况 + 建议 → 生成 embedding → 存入 ChromaDB
    • 读取:新情况 → 生成 embedding → 相似度搜索 → 找到最相关的历史建议

简化后代码

class FinancialSituationMemory:
    def __init__(self, name, config):
        self.config = config
        self.llm_provider = config.get("llm_provider", "openai").lower()
        # 配置向量缓存的长度限制(向量缓存默认启用长度检查)
        self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000'))  # 默认50K字符
        self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true'  # 向量缓存默认启用
        
        # 根据LLM提供商选择嵌入模型和客户端
        self.fallback_available = False # 初始化降级选项标志
        if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":
            self.embedding = "text-embedding-v3"
            self.client = None  # DashScope不需要OpenAI客户端

            # 设置DashScope API密钥
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            if dashscope_key:
                try:
                    # 尝试导入和初始化DashScope
                    import dashscope
                    from dashscope import TextEmbedding
                    dashscope.api_key = dashscope_key
                    logger.info(f"✅ DashScope API密钥已配置,启用记忆功能")
                except ImportError as e:
                    # DashScope包未安装 ...
                except Exception as e:
                    # 其他初始化错误 ...
            else:
                # 没有DashScope密钥,禁用记忆功能 ...
        elif self.llm_provider == "deepseek":
        	...

        # 使用单例ChromaDB管理器
        self.chroma_manager = ChromaDBManager()
        self.situation_collection = self.chroma_manager.get_or_create_collection(name)


    def _smart_text_truncation(self, text, max_length=8192):
        """智能文本截断,保持语义完整性和缓存兼容性"""
        if len(text) <= max_length:
            return text, False  # 返回原文本和是否截断的标志
        # 尝试在句子边界截断
        sentences = text.split('。')
        ...
        # 尝试在段落边界截断
        paragraphs = text.split('\n')
        ...
        return truncated, True

    def get_embedding(self, text):
        """Get embedding for a text using the configured provider"""
        # 检查记忆功能是否被禁用
        if self.client == "DISABLED":
            # 内存功能已禁用,返回空向量 ...
            
        if len(text) == 0: ... # 输入文本长度为0,返回空向量
        
        # 检查是否启用长度限制
        if self.enable_embedding_length_check and text_length > self.max_embedding_length:            
            # 文本过长跳过向量化并存储跳过信息 ...
            return [0.0] * 1024
            
        # 存储文本处理信息
        self._last_text_info = {  ...  }

        if (self.llm_provider == "dashscope" or
            self.llm_provider == "alibaba" or
            (self.llm_provider == "google" and self.client is None) or
            (self.llm_provider == "deepseek" and self.client is None) or
            (self.llm_provider == "openrouter" and self.client is None)):
            # 使用阿里百炼的嵌入模型
            try:    ...
                    return embedding
                else:
				...
                return [0.0] * 1024

    def get_embedding_config_status(self):
        """获取向量缓存配置状态"""
        return {
            'enabled': self.enable_embedding_length_check,
            'max_embedding_length': self.max_embedding_length,
            'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",
            'provider': self.llm_provider,
            'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'
        }

    def add_situations(self, situations_and_advice):
        """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
        ...
        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in advice],
            embeddings=embeddings,
            ids=ids,
        )



Example

    # Example usage
    matcher = FinancialSituationMemory()

    # Example data
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors 
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)

        for i, rec in enumerate(recommendations, 1):
            logger.info(f"\nMatch {i}:")
            logger.info(f"Similarity Score: {rec.get('similarity', 0):.2f}")
            logger.info(f"Matched Situation: {rec.get('situation', '')}")
            logger.info(f"Recommendation: {rec.get('recommendation', '')}")

    except Exception as e:
        logger.error(f"Error during recommendation: {str(e)}")

逐函数解析

__init__(self, config)

在这里插入图片描述

  • 作用
    初始化类,设置向量数据库(ChromaDB),并根据 config 和环境变量自动选择合适的 Embedding 服务。

  • 主要逻辑

    1. 保存 config
    2. 初始化一些内部状态(provider、model、status、last_text_info 等)。
    3. 根据 config["provider"] 来选择 embedding 服务:
      • DashScope / Alibaba → 用阿里云的 embedding API
      • DeepSeek → 优先用 DashScope,其次 OpenAI,最后 DeepSeek 自己的 embedding
      • Google → 优先 DashScope,如果配置里有 openai_api_key 就启用 fallback
      • OpenRouter → DashScope embedding
      • 本地 Ollama (localhost:11434) → 使用 nomic-embed-text
      • 默认 → 尝试 OpenAI embedding,失败则禁用
  1. 创建一个 ChromaDB 客户端,并建立/获取一个集合 financial_situations
  • 返回值 : 无(构造函数)。
    def __init__(self, name, config):
        self.config = config
        self.llm_provider = config.get("llm_provider", "openai").lower()

        # 配置向量缓存的长度限制(向量缓存默认启用长度检查)
        self.max_embedding_length = int(os.getenv('MAX_EMBEDDING_CONTENT_LENGTH', '50000'))  # 默认50K字符
        self.enable_embedding_length_check = os.getenv('ENABLE_EMBEDDING_LENGTH_CHECK', 'true').lower() == 'true'  # 向量缓存默认启用
        
        # 根据LLM提供商选择嵌入模型和客户端
        # 初始化降级选项标志
        self.fallback_available = False
        
        if self.llm_provider == "dashscope" or self.llm_provider == "alibaba":
            self.embedding = "text-embedding-v3"
            self.client = None  # DashScope不需要OpenAI客户端

            # 设置DashScope API密钥
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            if dashscope_key:
                try:
                    # 尝试导入和初始化DashScope
                    import dashscope
                    from dashscope import TextEmbedding

                    dashscope.api_key = dashscope_key
                    logger.info(f"✅ DashScope API密钥已配置,启用记忆功能")

                    # 可选:测试API连接(简单验证)
                    # 这里不做实际调用,只验证导入和密钥设置

                except ImportError as e:
                    # DashScope包未安装
                    logger.error(f"❌ DashScope包未安装: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ 记忆功能已禁用")

                except Exception as e:
                    # 其他初始化错误
                    logger.error(f"❌ DashScope初始化失败: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ 记忆功能已禁用")
            else:
                # 没有DashScope密钥,禁用记忆功能
                self.client = "DISABLED"
                logger.warning(f"⚠️ 未找到DASHSCOPE_API_KEY,记忆功能已禁用")
                logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
        elif self.llm_provider == "deepseek":
            # 检查是否强制使用OpenAI嵌入
            force_openai = os.getenv('FORCE_OPENAI_EMBEDDING', 'false').lower() == 'true'

            if not force_openai:
                # 尝试使用阿里百炼嵌入
                dashscope_key = os.getenv('DASHSCOPE_API_KEY')
                if dashscope_key:
                    try:
                        # 测试阿里百炼是否可用
                        import dashscope
                        from dashscope import TextEmbedding

                        dashscope.api_key = dashscope_key
                        # 验证TextEmbedding可用性(不需要实际调用)
                        self.embedding = "text-embedding-v3"
                        self.client = None
                        logger.info(f"💡 DeepSeek使用阿里百炼嵌入服务")
                    except ImportError as e:
                        logger.error(f"⚠️ DashScope包未安装: {e}")
                        dashscope_key = None  # 强制降级
                    except Exception as e:
                        logger.error(f"⚠️ 阿里百炼嵌入初始化失败: {e}")
                        dashscope_key = None  # 强制降级
            else:
                dashscope_key = None  # 跳过阿里百炼

            if not dashscope_key or force_openai:
                # 降级到OpenAI嵌入
                self.embedding = "text-embedding-3-small"
                openai_key = os.getenv('OPENAI_API_KEY')
                if openai_key:
                    self.client = OpenAI(
                        api_key=openai_key,
                        base_url=config.get("backend_url", "https://api.openai.com/v1")
                    )
                    logger.warning(f"⚠️ DeepSeek回退到OpenAI嵌入服务")
                else:
                    # 最后尝试DeepSeek自己的嵌入
                    deepseek_key = os.getenv('DEEPSEEK_API_KEY')
                    if deepseek_key:
                        try:
                            self.client = OpenAI(
                                api_key=deepseek_key,
                                base_url="https://api.deepseek.com"
                            )
                            logger.info(f"💡 DeepSeek使用自己的嵌入服务")
                        except Exception as e:
                            logger.error(f"❌ DeepSeek嵌入服务不可用: {e}")
                            # 禁用内存功能
                            self.client = "DISABLED"
                            logger.info(f"🚨 内存功能已禁用,系统将继续运行但不保存历史记忆")
                    else:
                        # 禁用内存功能而不是抛出异常
                        self.client = "DISABLED"
                        logger.info(f"🚨 未找到可用的嵌入服务,内存功能已禁用")
        elif self.llm_provider == "google":
            # Google AI使用阿里百炼嵌入(如果可用),否则禁用记忆功能
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            openai_key = os.getenv('OPENAI_API_KEY')
            
            if dashscope_key:
                try:
                    # 尝试初始化DashScope
                    import dashscope
                    from dashscope import TextEmbedding

                    self.embedding = "text-embedding-v3"
                    self.client = None
                    dashscope.api_key = dashscope_key
                    
                    # 检查是否有OpenAI密钥作为降级选项
                    if openai_key:
                        logger.info(f"💡 Google AI使用阿里百炼嵌入服务(OpenAI作为降级选项)")
                        self.fallback_available = True
                        self.fallback_client = OpenAI(api_key=openai_key, base_url=config["backend_url"])
                        self.fallback_embedding = "text-embedding-3-small"
                    else:
                        logger.info(f"💡 Google AI使用阿里百炼嵌入服务(无降级选项)")
                        self.fallback_available = False
                        
                except ImportError as e:
                    logger.error(f"❌ DashScope包未安装: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ Google AI记忆功能已禁用")
                except Exception as e:
                    logger.error(f"❌ DashScope初始化失败: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ Google AI记忆功能已禁用")
            else:
                # 没有DashScope密钥,禁用记忆功能
                self.client = "DISABLED"
                self.fallback_available = False
                logger.warning(f"⚠️ Google AI未找到DASHSCOPE_API_KEY,记忆功能已禁用")
                logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
        elif self.llm_provider == "openrouter":
            # OpenRouter支持:优先使用阿里百炼嵌入,否则禁用记忆功能
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            if dashscope_key:
                try:
                    # 尝试使用阿里百炼嵌入
                    import dashscope
                    from dashscope import TextEmbedding

                    self.embedding = "text-embedding-v3"
                    self.client = None
                    dashscope.api_key = dashscope_key
                    logger.info(f"💡 OpenRouter使用阿里百炼嵌入服务")
                except ImportError as e:
                    logger.error(f"❌ DashScope包未安装: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ OpenRouter记忆功能已禁用")
                except Exception as e:
                    logger.error(f"❌ DashScope初始化失败: {e}")
                    self.client = "DISABLED"
                    logger.warning(f"⚠️ OpenRouter记忆功能已禁用")
            else:
                # 没有DashScope密钥,禁用记忆功能
                self.client = "DISABLED"
                logger.warning(f"⚠️ OpenRouter未找到DASHSCOPE_API_KEY,记忆功能已禁用")
                logger.info(f"💡 系统将继续运行,但不会保存或检索历史记忆")
        elif config["backend_url"] == "http://localhost:11434/v1":
            self.embedding = "nomic-embed-text"
            self.client = OpenAI(base_url=config["backend_url"])
        else:
            self.embedding = "text-embedding-3-small"
            openai_key = os.getenv('OPENAI_API_KEY')
            if openai_key:
                self.client = OpenAI(
                    api_key=openai_key,
                    base_url=config["backend_url"]
                )
            else:
                self.client = "DISABLED"
                logger.warning(f"⚠️ 未找到OPENAI_API_KEY,记忆功能已禁用")

        # 使用单例ChromaDB管理器
        self.chroma_manager = ChromaDBManager()
        self.situation_collection = self.chroma_manager.get_or_create_collection(name)

_smart_text_truncation(self, text, max_length)

    def _smart_text_truncation(self, text, max_length=8192):
        """智能文本截断,保持语义完整性和缓存兼容性"""
        if len(text) <= max_length:
            return text, False  # 返回原文本和是否截断的标志
        
        # 尝试在句子边界截断
        sentences = text.split('。')
        if len(sentences) > 1:
            truncated = ""
            for sentence in sentences:
                if len(truncated + sentence + '。') <= max_length - 50:  # 留50字符余量
                    truncated += sentence + '。'
                else:
                    break
            if len(truncated) > max_length // 2:  # 至少保留一半内容
                logger.info(f"📝 智能截断:在句子边界截断,保留{len(truncated)}/{len(text)}字符")
                return truncated, True
        
        # 尝试在段落边界截断
        paragraphs = text.split('\n')
        if len(paragraphs) > 1:
            truncated = ""
            for paragraph in paragraphs:
                if len(truncated + paragraph + '\n') <= max_length - 50:
                    truncated += paragraph + '\n'
                else:
                    break
            if len(truncated) > max_length // 2:
                logger.info(f"📝 智能截断:在段落边界截断,保留{len(truncated)}/{len(text)}字符")
                return truncated, True
        
        # 最后选择:保留前半部分和后半部分的关键信息
        front_part = text[:max_length//2]
        back_part = text[-(max_length//2-100):]  # 留100字符给连接符
        truncated = front_part + "\n...[内容截断]...\n" + back_part
        logger.warning(f"⚠️ 强制截断:保留首尾关键信息,{len(text)}字符截断为{len(truncated)}字符")
        return truncated, True
  • 作用
    保证输入文本不会超过 max_length,但尽量保持语义完整(比如按句子、段落截断)。

  • 主要逻辑

    1. 如果文本长度 ≤ max_length → 原样返回。
    2. 如果太长:
      • 尝试在句号(。.!?)之后截断,保留前面一段完整句子。
      • 否则尝试按段落 \n 截断。
      • 如果都不行,就取前 max_length//2 和后 max_length//2 拼接,中间插入 ...
    3. 把截断后的文本保存到 self.last_text_info,记录是否截断、采用哪种策略、原始/最终长度等。
  • 输入

    • text (str):原始文本
    • max_length (int):允许的最大长度
  • 输出

    • 截断后的文本 (str)

get_embedding(self, text, max_length=1000)

    def get_embedding(self, text):
        """Get embedding for a text using the configured provider"""

        # 检查记忆功能是否被禁用
        if self.client == "DISABLED":
            # 内存功能已禁用,返回空向量
            logger.debug(f"⚠️ 记忆功能已禁用,返回空向量")
            return [0.0] * 1024  # 返回1024维的零向量

        # 验证输入文本
        if not text or not isinstance(text, str):
            logger.warning(f"⚠️ 输入文本为空或无效,返回空向量")
            return [0.0] * 1024

        text_length = len(text)
        if text_length == 0:
            logger.warning(f"⚠️ 输入文本长度为0,返回空向量")
            return [0.0] * 1024
        
        # 检查是否启用长度限制
        if self.enable_embedding_length_check and text_length > self.max_embedding_length:
            logger.warning(f"⚠️ 文本过长({text_length:,}字符 > {self.max_embedding_length:,}字符),跳过向量化")
            # 存储跳过信息
            self._last_text_info = {
                'original_length': text_length,
                'processed_length': 0,
                'was_truncated': False,
                'was_skipped': True,
                'provider': self.llm_provider,
                'strategy': 'length_limit_skip',
                'max_length': self.max_embedding_length
            }
            return [0.0] * 1024
        
        # 记录文本信息(不进行任何截断)
        if text_length > 8192:
            logger.info(f"📝 处理长文本: {text_length}字符,提供商: {self.llm_provider}")
        
        # 存储文本处理信息
        self._last_text_info = {
            'original_length': text_length,
            'processed_length': text_length,  # 不截断,保持原长度
            'was_truncated': False,  # 永不截断
            'was_skipped': False,
            'provider': self.llm_provider,
            'strategy': 'no_truncation_with_fallback'  # 标记策略
        }

        if (self.llm_provider == "dashscope" or
            self.llm_provider == "alibaba" or
            (self.llm_provider == "google" and self.client is None) or
            (self.llm_provider == "deepseek" and self.client is None) or
            (self.llm_provider == "openrouter" and self.client is None)):
            # 使用阿里百炼的嵌入模型
            try:
                # 导入DashScope模块
                import dashscope
                from dashscope import TextEmbedding

                # 检查DashScope API密钥是否可用
                if not hasattr(dashscope, 'api_key') or not dashscope.api_key:
                    logger.warning(f"⚠️ DashScope API密钥未设置,记忆功能降级")
                    return [0.0] * 1024  # 返回空向量

                # 尝试调用DashScope API
                response = TextEmbedding.call(
                    model=self.embedding,
                    input=text
                )

                # 检查响应状态
                if response.status_code == 200:
                    # 成功获取embedding
                    embedding = response.output['embeddings'][0]['embedding']
                    logger.debug(f"✅ DashScope embedding成功,维度: {len(embedding)}")
                    return embedding
                else:
                    # API返回错误状态码
                    error_msg = f"{response.code} - {response.message}"
                    
                    # 检查是否为长度限制错误
                    if any(keyword in error_msg.lower() for keyword in ['length', 'token', 'limit', 'exceed']):
                        logger.warning(f"⚠️ DashScope长度限制: {error_msg}")
                        
                        # 检查是否有降级选项
                        if hasattr(self, 'fallback_available') and self.fallback_available:
                            logger.info(f"💡 尝试使用OpenAI降级处理长文本")
                            try:
                                response = self.fallback_client.embeddings.create(
                                    model=self.fallback_embedding,
                                    input=text
                                )
                                embedding = response.data[0].embedding
                                logger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")
                                return embedding
                            except Exception as fallback_error:
                                logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")
                                logger.info(f"💡 所有降级选项失败,记忆功能降级")
                                return [0.0] * 1024
                        else:
                            logger.info(f"💡 无可用降级选项,记忆功能降级")
                            return [0.0] * 1024
                    else:
                        logger.error(f"❌ DashScope API错误: {error_msg}")
                        return [0.0] * 1024  # 返回空向量而不是抛出异常

            except Exception as e:
                error_str = str(e).lower()
                
                # 检查是否为长度限制错误
                if any(keyword in error_str for keyword in ['length', 'token', 'limit', 'exceed', 'too long']):
                    logger.warning(f"⚠️ DashScope长度限制异常: {str(e)}")
                    
                    # 检查是否有降级选项
                    if hasattr(self, 'fallback_available') and self.fallback_available:
                        logger.info(f"💡 尝试使用OpenAI降级处理长文本")
                        try:
                            response = self.fallback_client.embeddings.create(
                                model=self.fallback_embedding,
                                input=text
                            )
                            embedding = response.data[0].embedding
                            logger.info(f"✅ OpenAI降级成功,维度: {len(embedding)}")
                            return embedding
                        except Exception as fallback_error:
                            logger.error(f"❌ OpenAI降级失败: {str(fallback_error)}")
                            logger.info(f"💡 所有降级选项失败,记忆功能降级")
                            return [0.0] * 1024
                    else:
                        logger.info(f"💡 无可用降级选项,记忆功能降级")
                        return [0.0] * 1024
                elif 'import' in error_str:
                    logger.error(f"❌ DashScope包未安装: {str(e)}")
                elif 'connection' in error_str:
                    logger.error(f"❌ DashScope网络连接错误: {str(e)}")
                elif 'timeout' in error_str:
                    logger.error(f"❌ DashScope请求超时: {str(e)}")
                else:
                    logger.error(f"❌ DashScope embedding异常: {str(e)}")
                
                logger.warning(f"⚠️ 记忆功能降级,返回空向量")
                return [0.0] * 1024
        else:
            # 使用OpenAI兼容的嵌入模型
            if self.client is None:
                logger.warning(f"⚠️ 嵌入客户端未初始化,返回空向量")
                return [0.0] * 1024  # 返回空向量
            elif self.client == "DISABLED":
                # 内存功能已禁用,返回空向量
                logger.debug(f"⚠️ 内存功能已禁用,返回空向量")
                return [0.0] * 1024  # 返回1024维的零向量

            # 尝试调用OpenAI兼容的embedding API
            try:
                response = self.client.embeddings.create(
                    model=self.embedding,
                    input=text
                )
                embedding = response.data[0].embedding
                logger.debug(f"✅ {self.llm_provider} embedding成功,维度: {len(embedding)}")
                return embedding

            except Exception as e:
                error_str = str(e).lower()
                
                # 检查是否为长度限制错误
                length_error_keywords = [
                    'token', 'length', 'too long', 'exceed', 'maximum', 'limit',
                    'context', 'input too large', 'request too large'
                ]
                
                is_length_error = any(keyword in error_str for keyword in length_error_keywords)
                
                if is_length_error:
                    # 长度限制错误:直接降级,不截断重试
                    logger.warning(f"⚠️ {self.llm_provider}长度限制: {str(e)}")
                    logger.info(f"💡 为保证分析准确性,不截断文本,记忆功能降级")
                else:
                    # 其他类型的错误
                    if 'attributeerror' in error_str:
                        logger.error(f"❌ {self.llm_provider} API调用错误: {str(e)}")
                    elif 'connectionerror' in error_str or 'connection' in error_str:
                        logger.error(f"❌ {self.llm_provider}网络连接错误: {str(e)}")
                    elif 'timeout' in error_str:
                        logger.error(f"❌ {self.llm_provider}请求超时: {str(e)}")
                    elif 'keyerror' in error_str:
                        logger.error(f"❌ {self.llm_provider}响应格式错误: {str(e)}")
                    else:
                        logger.error(f"❌ {self.llm_provider} embedding异常: {str(e)}")
                
                logger.warning(f"⚠️ 记忆功能降级,返回空向量")
                return [0.0] * 1024
  • 作用 : 把文本转成向量 embedding。

  • 主要逻辑

    1. 调用 _smart_text_truncation 保证长度安全。
    2. 根据 provider 选择对应的 API:
      • DashScope → 调用 dashscope.TextEmbedding.call,模型是 text-embedding-v2
      • OpenAI → 调用 client.embeddings.create(model="text-embedding-3-small")
      • DeepSeek → 可能走 DeepSeek 自己的 embedding API
      • Ollama → 调用 client.embeddings.create(model="nomic-embed-text")
      • 其它情况 → 返回 [0.0] * 1024(代表禁用)
    3. 如果调用失败,捕获异常并返回零向量。
  • 输入

    • text (str):输入文本
    • max_length (int):最大允许长度(默认 1000)
  • 输出

    • 向量 embedding (list[float])

get_embedding_config_status(self)

    def get_embedding_config_status(self):
        """获取向量缓存配置状态"""
        return {
            'enabled': self.enable_embedding_length_check,
            'max_embedding_length': self.max_embedding_length,
            'max_embedding_length_formatted': f"{self.max_embedding_length:,}字符",
            'provider': self.llm_provider,
            'client_status': 'DISABLED' if self.client == "DISABLED" else 'ENABLED'
        }
  • 作用 返回当前 embedding 的配置信息,主要用于调试/检查状态。

get_last_text_info(self)

    def get_last_text_info(self):
        """获取最后处理的文本信息"""
        return getattr(self, '_last_text_info', None)
  • 作用:返回最近一次 _smart_text_truncation 的信息(调试用)。

add_situations(self, situations)

    def add_situations(self, situations_and_advice):
        """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""

        situations = []
        advice = []
        ids = []
        embeddings = []

        offset = self.situation_collection.count()

        for i, (situation, recommendation) in enumerate(situations_and_advice):
            situations.append(situation)
            advice.append(recommendation)
            ids.append(str(offset + i))
            embeddings.append(self.get_embedding(situation))

        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in advice],
            embeddings=embeddings,
            ids=ids,
        )
  • 作用:把一组新的“财务情况 + 建议”存入数据库。

主要逻辑

  1. 遍历 situations,每个元素是 (situation, recommendation)

  2. situation 文本生成 embedding。

  3. 构造数据项:

    {
      "id": str(uuid.uuid4()),
      "embedding": 向量,
      "metadata": {"recommendation": 建议}
    }
    
  4. 批量写入 self.collection(ChromaDB)。

get_cache_info(self)

    def get_cache_info(self):
        """获取缓存相关信息,用于调试和监控"""
        info = {
            'collection_count': self.situation_collection.count(),
            'client_status': 'enabled' if self.client != "DISABLED" else 'disabled',
            'embedding_model': self.embedding,
            'provider': self.llm_provider
        }
        
        # 添加最后一次文本处理信息
        if hasattr(self, '_last_text_info'):
            info['last_text_processing'] = self._last_text_info
            
        return info
  • 作用
    返回当前 ChromaDB 集合的一些元信息(比如名称、模型、provider 等),帮助确认系统运行状态。

get_memories(self, situation, n_results=5)

    def get_memories(self, current_situation, n_matches=1):
        """Find matching recommendations using embeddings with smart truncation handling"""
        
        # 获取当前情况的embedding
        query_embedding = self.get_embedding(current_situation)
        
        # 检查是否为空向量(记忆功能被禁用或出错)
        if all(x == 0.0 for x in query_embedding):
            logger.debug(f"⚠️ 查询embedding为空向量,返回空结果")
            return []
        
        # 检查是否有足够的数据进行查询
        collection_count = self.situation_collection.count()
        if collection_count == 0:
            logger.debug(f"📭 记忆库为空,返回空结果")
            return []
        
        # 调整查询数量,不能超过集合中的文档数量
        actual_n_matches = min(n_matches, collection_count)
        
        try:
            # 执行相似度查询
            results = self.situation_collection.query(
                query_embeddings=[query_embedding],
                n_results=actual_n_matches
            )
            
            # 处理查询结果
            memories = []
            if results and 'documents' in results and results['documents']:
                documents = results['documents'][0]
                metadatas = results.get('metadatas', [[]])[0]
                distances = results.get('distances', [[]])[0]
                
                for i, doc in enumerate(documents):
                    metadata = metadatas[i] if i < len(metadatas) else {}
                    distance = distances[i] if i < len(distances) else 1.0
                    
                    memory_item = {
                        'situation': doc,
                        'recommendation': metadata.get('recommendation', ''),
                        'similarity': 1.0 - distance,  # 转换为相似度分数
                        'distance': distance
                    }
                    memories.append(memory_item)
                
                # 记录查询信息
                if hasattr(self, '_last_text_info') and self._last_text_info.get('was_truncated'):
                    logger.info(f"🔍 截断文本查询完成,找到{len(memories)}个相关记忆")
                    logger.debug(f"📊 原文长度: {self._last_text_info['original_length']}, "
                               f"处理后长度: {self._last_text_info['processed_length']}")
                else:
                    logger.debug(f"🔍 记忆查询完成,找到{len(memories)}个相关记忆")
            
            return memories
            
        except Exception as e:
            logger.error(f"❌ 记忆查询失败: {str(e)}")
            return []
  • 作用:检索最相似的历史情况,返回对应的建议。

  • 主要逻辑

    1. 对输入 situation 生成 embedding。
    2. 调用 self.collection.query,检索最相似的 n_results 条记录。
    3. 解析返回结果:提取文本、embedding 相似度/距离、推荐建议。
    4. 返回一个结果列表。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值