一. Preface
This article is one of a series analyzing the LangChain framework's source code, approaching today's hottest LLM application development framework from the perspective of engineering design. Compared with studying how large models work, learning the LangChain framework requires a shift in mindset: the focus moves from Why to What & How.
What is a callback? Callbacks are a very common design, and every engineer has their own understanding of them. The author's own understanding of callback design starts from the synchronous/asynchronous invocation styles of the I/O model.
•Synchronous call - after issuing a request, the caller's thread (or coroutine or similar abstraction) must block until the result returns;
•Asynchronous call - after issuing a request, the caller's thread (or coroutine or similar abstraction) returns immediately.
For an asynchronous call, all is well if the caller does not care about the result; but what if the result must be processed? Callbacks are the complementary design for exactly that scenario: an event-response mechanism combined with a callback completes the handling of the asynchronous call's result.
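The distinction, and the role a callback plays in handling an asynchronous result, can be sketched with a toy thread-based example (an illustration only, unrelated to LangChain's actual code):

```python
import threading
import time

def compute(x: int) -> int:
    """Simulate a slow I/O-bound computation."""
    time.sleep(0.01)
    return x * 2

# Synchronous call: the caller blocks until the result returns.
result_sync = compute(21)

# Asynchronous call: the caller returns immediately; a callback
# handles the result when it arrives (event-driven completion).
results = []

def on_done(value: int) -> None:
    results.append(value)

def async_compute(x: int, callback) -> threading.Thread:
    t = threading.Thread(target=lambda: callback(compute(x)))
    t.start()
    return t  # the caller is free to do other work here

t = async_compute(21, on_done)
t.join()  # wait only so the demo can inspect the result
print(result_sync, results)  # 42 [42]
```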
Reading source code is a fairly dry process, and interleaving source code in an article inevitably affects readability. So, conclusion first: LangChain's callback capability is the combination of the Python asyncio standard library and a custom event system built around the core LLM abstractions.
To avoid bouncing back and forth between the article and the codebase, large amounts of module source code are quoted in the text; readers who care only about the design and not the code may skip the code listings.
二. Main Text
1. The Code Map
Starting from the code's perspective:
├──libs
│ ├── core
│ │ ├── langchain_core
│ │ │ ├── callbacks
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── manager.py
│ │ │ │ ├── stdout.py
│ │ │ │ ├── streaming_stdout.py
For readers familiar with Python, dissecting the LangChain framework can be genuinely pleasant: the engineering quality is high enough that even the author, an outsider to Python, could get up to reading speed quickly.
We enter the analysis of the callbacks module through the language-level entry point, __init__.py, to establish the list of classes to examine.
__all__ = [
"RetrieverManagerMixin",
"LLMManagerMixin",
"ChainManagerMixin",
"ToolManagerMixin",
"Callbacks",
"CallbackManagerMixin",
"RunManagerMixin",
"BaseCallbackHandler",
"AsyncCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"RunManager",
"ParentRunManager",
"AsyncRunManager",
"AsyncParentRunManager",
"CallbackManagerForLLMRun",
"AsyncCallbackManagerForLLMRun",
"CallbackManagerForChainRun",
"AsyncCallbackManagerForChainRun",
"CallbackManagerForToolRun",
"AsyncCallbackManagerForToolRun",
"CallbackManagerForRetrieverRun",
"AsyncCallbackManagerForRetrieverRun",
"CallbackManager",
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",]
Software practitioners share a common body of knowledge, and code follows common naming conventions, so we can "read meaning from names": a simple classification by name yields a module organization, our "code map".
Splitting on keywords this way, the 28 classes reveal two cores:
•LLM abstractions (LLM, Retriever, Chain, Tool)
•callback abstractions (Callback, CallbackManager, RunManager, CallbackHandler).
This matches expectations: at the module level, the callback capability is split into the LLM abstractions and the callback abstractions.
Tip: this map is merely the author's attempt at simplifying comprehension; whether it is redundant busywork compared with UML is debatable.
With this preparation done, the source analysis can begin. The author treats source analysis as a way to learn coding style and design technique. Coding style is directly visible; design technique can be approached by reversing the usual flow: engineering goes [design -> code], source analysis goes [code -> design].
One question remains when analyzing from a "code map": where to start? With a complex map this becomes acute, since too many loose threads induce choice paralysis. The author's advice: follow personal preference, and if a chosen starting point meets heavy resistance in understanding, do not agonize; switch decisively. From the final design's perspective the points are necessarily interrelated and will eventually close into a coherent design system.
Tip: the relations between points follow UML class-diagram relationships.
To reduce the complexity of the analysis, the six most common UML relationships (dependency, association, composition, aggregation, realization, generalization) can be collapsed into two:
1. (dependency, association, composition, aggregation) => association [horizontal extension]
2. (realization, generalization) => inheritance [vertical extension]
With this simplification we explore the "code map" by extending horizontally and vertically.
2. Exploring the Map
step1. Callbacks
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
Equivalently:
Callbacks = Union[List[BaseCallbackHandler], BaseCallbackManager, None]
That completes the source analysis of Callbacks, and it hands us the next targets: BaseCallbackHandler and BaseCallbackManager. As parallel members of a Union, we can hypothesize that the two offer equivalent capabilities (Hypothesis 1).
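The "equivalent transformation" above can be verified directly with the typing module; the stub classes below stand in for the real LangChain ones:

```python
from typing import List, Optional, Union

# Stand-ins for the real LangChain classes (illustration only).
class BaseCallbackHandler: ...
class BaseCallbackManager: ...

# The two spellings are literally the same type: Optional[X] is
# shorthand for Union[X, None], and nested Unions are flattened.
A = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
B = Union[List[BaseCallbackHandler], BaseCallbackManager, None]
print(A == B)  # True
```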
step2. BaseCallbackHandler
class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,):
"""Base callback handler that handles callbacks from LangChain."""
raise_error: bool = False
run_inline: bool = False
@property
def ignore_llm(self)->bool:
"""Whether to ignore LLM callbacks."""
return False
@property
def ignore_retry(self)->bool:
"""Whether to ignore retry callbacks."""
return False
@property
def ignore_chain(self)->bool:
"""Whether to ignore chain callbacks."""
return False
@property
def ignore_agent(self)->bool:
"""Whether to ignore agent callbacks."""
return False
@property
def ignore_retriever(self)->bool:
"""Whether to ignore retriever callbacks."""
return False
@property
def ignore_chat_model(self)->bool:
"""Whether to ignore chat model callbacks."""
return False
Plenty of code, but the structure is clear.
•First, it multiply inherits from many mixin classes. For those familiar with the Mixin pattern this may feel slightly awkward; the pattern mainly exists to work around the limits of class reuse under single inheritance. But design is partly an art, writing essays in the language of code, and everyone has their own preferred style.
•Second, a series of filtering properties acts as template methods that standardize the "ignore this callback" behavior.
--- The Mixin pattern ---
"Mixin" literally means to mix in. In a software-design context, mixin refers to a specific design pattern, which we may as well call the Mixin pattern.
Per Wikipedia's definition, a mixin is a class containing methods that can be used by other classes without being their parent; a mixin provides methods implementing a particular behavior, but from the user's perspective those methods only become convenient when combined with other classes.
Official definitions rarely read plainly, so in plain words: a mixin is a class of pure method definitions, a "blend-in" mixed into wherever it is needed.
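A minimal illustration of the idea (JSONSerializableMixin is a hypothetical example, not from LangChain):

```python
import json

class JSONSerializableMixin:
    """A pure-behavior class: it defines methods meant to be mixed into
    other classes rather than used on its own."""
    def to_json(self) -> str:
        # Serialize whatever instance attributes the host class defines.
        return json.dumps(self.__dict__)

class Point(JSONSerializableMixin):
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

# Point gains to_json "for free" by mixing the behavior in.
print(Point(1, 2).to_json())  # {"x": 1, "y": 2}
```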
*Mixin
These mixins define the behavioral contract for event-response handling.
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
class LLMManagerMixin:
"""Mixin for LLM callbacks."""
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled.
Args:
token (str): The new token.
chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
containing content and other information.
"""
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM errors.
Args:
error (BaseException): The error that occurred.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
"""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
class ToolManagerMixin:
"""Mixin for tool callbacks."""
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
class CallbackManagerMixin:
"""Mixin for callback manager."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running.
**ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model,
you should use on_chat_model_start instead.
"""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running.
**ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this exception is thrown
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class RunManagerMixin:
"""Mixin for run manager."""
def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
step3. BaseCallbackManager
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain."""
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
if handler not in self.handlers:
self.handlers.append(handler)
if inherit and handler not in self.inheritable_handlers:
self.inheritable_handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)
Apart from the is_async, tags, and metadata keywords, the rest should be summarizable in a single diagram.
This further corroborates Hypothesis 1.
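To make the add/set semantics concrete, here is a trimmed re-implementation of just the handler bookkeeping (a sketch; MiniCallbackManager is not a LangChain class):

```python
from typing import List

class Handler:
    """Stand-in for BaseCallbackHandler."""

class MiniCallbackManager:
    """Trimmed re-implementation of BaseCallbackManager's handler
    bookkeeping, to make add_handler's semantics concrete."""
    def __init__(self) -> None:
        self.handlers: List[Handler] = []
        self.inheritable_handlers: List[Handler] = []

    def add_handler(self, handler: Handler, inherit: bool = True) -> None:
        # Duplicates are ignored; the inheritable list is opt-in via `inherit`.
        if handler not in self.handlers:
            self.handlers.append(handler)
        if inherit and handler not in self.inheritable_handlers:
            self.inheritable_handlers.append(handler)

mgr = MiniCallbackManager()
h = Handler()
mgr.add_handler(h)
mgr.add_handler(h)  # second add is a no-op
print(len(mgr.handlers), len(mgr.inheritable_handlers))  # 1 1
```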
Extending outward from Callbacks as the starting point covers the red region of the map; at this point the directly visible relations between classes [directly visible - evident in code; relations - UML association and inheritance-from-parent] take us no further. Looking in the IDE reveals an interesting fact: all the class definitions so far converge on the base.py file (a sign of excellent engineering quality).
We also find one more class in base.py: AsyncCallbackHandler. It does not participate in the directly visible relations, so we may boldly assume it must sit on the other side of inheritance, as a subclass.
step4. AsyncCallbackHandler
It adds a full set of asynchronous counterparts to the event-response handler methods.
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that handles callbacks from LangChain."""
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running.
**ATTENTION**: This method is called for non-chat models (regular LLMs). If
you're implementing a handler for a chat model,
you should use on_chat_model_start instead.
"""
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running.
**ATTENTION**: This method is called for chat models. If you're implementing
a handler for a non-chat model, you should use on_llm_start instead.
"""
# NotImplementedError is thrown intentionally
# Callback handler will fall back to on_llm_start if this exception is thrown
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors.
Args:
error: The error that occurred.
kwargs (Any): Additional keyword arguments.
- response (LLMResult): The response which was generated before
the error occurred.
"""
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
async def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
async def on_retry(
self,
retry_state: RetryCallState,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on a retry event."""
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
serialized: Dict[str, Any],
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
Continuing vertically on the map from BaseCallbackHandler, we find two classes, StdOutCallbackHandler and StreamingStdOutCallbackHandler. From the names they might appear related by inheritance; in fact they are not.
step5. StdOutCallbackHandler
The stdout.py file holds only the StdOutCallbackHandler class definition.
It implements a subset of the event-response handlers, printing progress information to standard output (e.g. the terminal console).
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") # noqa: T201
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m") # noqa: T201
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color or self.color)
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text(text, color=color or self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color or self.color, end="\n")
StreamingStdOutCallbackHandler
The streaming_stdout.py file holds only the StreamingStdOutCallbackHandler class definition. It declares a subset of the event-response handlers but implements only on_llm_new_token, whose body simply writes the token to standard output and calls flush() so the content appears immediately. Combined with streaming, this renders the LLM's streamed tokens to standard output (e.g. the terminal console) in real time.
class StreamingStdOutCallbackHandler(BaseCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
From the engineering perspective, only the manager.py file remains; yet on the map, the classes centered on RunManager and CallbackManager still occupy half the territory. This matches our understanding: so far we have seen only the skeleton of the callback module (declarations), not the logic that makes it work (implementations); the core content must therefore be concentrated in this last file.
step6. BaseRunManager
Learn to read the docstring: """Base class for run manager (a bound callback manager).""" breaks into two key points:
•the base class for run managers
•a callback manager bound to a run
class BaseRunManager(RunManagerMixin):
"""Base class for run manager (a bound callback manager)."""
def __init__(
self,
*,
run_id: UUID,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler],
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the run manager.
Args:
run_id (UUID): The ID of the run.
handlers (List[BaseCallbackHandler]): The list of handlers.
inheritable_handlers (List[BaseCallbackHandler]):
The list of inheritable handlers.
parent_run_id (UUID, optional): The ID of the parent run.
Defaults to None.
tags (Optional[List[str]]): The list of tags.
inheritable_tags (Optional[List[str]]): The list of inheritable tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata.
"""
self.run_id = run_id
self.handlers = handlers
self.inheritable_handlers = inheritable_handlers
self.parent_run_id = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations.
Returns:
BaseRunManager: The noop manager.
"""
return cls(
run_id=uuid.uuid4(),
handlers=[],
inheritable_handlers=[],
tags=[],
inheritable_tags=[],
metadata={},
inheritable_metadata={},
)
RunManager
class RunManager(BaseRunManager):
"""Sync Run Manager."""
def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received.
Args:
text (str): The received text.
Returns:
Any: The result of the callback.
"""
handle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
def on_retry(
self,
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
handle_event(
self.handlers,
"on_retry",
"ignore_retry",
retry_state,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
[**Core implementation 1**] handle_event (event handling)
def handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for CallbackManager.【回调管理器的通用事件处理器】
Note: This function is used by langserve to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
### declare the "coroutine" list
coros: List[Coroutine[Any, Any, Any]] = []
try:
message_strings: Optional[List[str]] = None
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
coros.append(event)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
if message_strings is None:
message_strings = [get_buffer_string(m) for m in args[1]]
handle_event(
[handler],
"on_llm_start",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
handler_name = handler.__class__.__name__
logger.warning(
f"NotImplementedError in {handler_name}.{event_name}"
f" callback: {repr(e)}"
)
except Exception as e:
logger.warning(
f"Error in {handler.__class__.__name__}.{event_name} callback:"
f" {repr(e)}"
)
if handler.raise_error:
raise e
finally:
### if coros is non-empty (contains at least one element)
if coros:
try:
# Raises RuntimeError if there is no current event loop.
asyncio.get_running_loop()
loop_running = True
except RuntimeError:
loop_running = False
if loop_running:
# If we try to submit this coroutine to the running loop
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with ThreadPoolExecutor(1) as executor:
executor.submit(
cast(Callable, copy_context().run), _run_coros, coros
).result()
else:
_run_coros(coros)
Only two Python-specific constructs in this method may be unfriendly to Java readers; the rest does not hinder reading at all.
•getattr(obj, attribute_name) - a built-in function (think of it as globally defined) that retrieves an attribute of an object. Note that because Python treats methods as attributes, attribute_name can name either a data attribute or a method; here it names a method.
It must be said that the "unifying abstraction" design philosophy is adopted by many excellent products: Linux - everything is a file; Kubernetes - everything is a resource; Python - everything is an attribute.
•asyncio
asyncio is Python's standard library for concurrent programming (added in version 3.4). It is used to write single-threaded concurrent code with coroutines, mostly for I/O-bound workloads, and provides a simple way to run concurrent tasks without explicitly managing threads or processes. Its core components include coroutines, the event loop (get_event_loop()/get_running_loop()), tasks, Futures, synchronization primitives, transports and protocols, streams, and queues.
### Coroutines
- **Lightweight**: coroutines are very lightweight units of execution, scheduled by the event loop within a single thread. Switching between coroutines is controlled explicitly in code, usually via the `await` keyword, which suspends the current coroutine and hands control back to the event loop until some operation completes.
- **Cooperative multitasking**: a coroutine must voluntarily yield control before other coroutines can run. Because of this, coroutines generally need no mutexes or other concurrency-control structures during execution, reducing the risk of deadlock.
- **Suited to I/O-bound tasks**: coroutines fit I/O-bound work such as network requests and file I/O, where most of the time is spent waiting for external events; a coroutine can yield the CPU during those waits so other tasks run.
### Threads
- **Heavyweight**: creating and switching threads costs more than coroutines, because threads require direct operating-system support. Thread scheduling is done by the OS kernel, with no explicit management needed in code.
- **Preemptive multitasking**: the OS decides which thread gets CPU time and switches between threads, so a thread may be interrupted at any moment in favor of another.
- **Suited to CPU-bound tasks**: threads fit CPU-bound work such as computation and data processing, since the OS can schedule them onto multiple CPU cores in parallel (setting aside Python's GIL, which limits CPU parallelism for pure-Python threads).
### Summary
In Python, coroutines and threads are both schedulable units, but they differ in design and purpose: coroutines are lighter, fit I/O-bound work, and are scheduled cooperatively; threads fit CPU-bound work and are scheduled preemptively by the OS.
### Note
Coroutines are not an operating-system facility but a schedulable abstraction provided at the Python language level, which the phrase "scheduled by the event loop within a single thread" above makes very clear. It also shows how naturally coroutines fit an event-response mechanism.
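Cooperative scheduling within a single thread can be observed directly (a small demo with names chosen for illustration):

```python
import asyncio

order = []

async def worker(name: str, delay: float) -> None:
    order.append(f"{name}:start")
    await asyncio.sleep(delay)  # yields control back to the event loop
    order.append(f"{name}:end")

async def main() -> None:
    # Both coroutines run concurrently inside ONE thread: while "a"
    # awaits, the loop schedules "b" - cooperative multitasking.
    await asyncio.gather(worker("a", 0.02), worker("b", 0.01))

asyncio.run(main())
print(order)  # ['a:start', 'b:start', 'b:end', 'a:end']
```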
Back to the code analysis. Following divide-and-conquer, split the method into several parts.
1.1 The try block
[Logic] If the ignore condition is not met [no ignore condition was given, or the handler does not allow ignoring], dispatch the event.
Dispatching here covers two modes [synchronous / asynchronous execution].
event = getattr(handler, event_name)(*args, **kwargs)
if asyncio.iscoroutine(event):
coros.append(event)
[Readers very familiar with Python may skip this.] The three lines above accommodate both sync and async modes; a brief explanation follows.
'''
### Coroutine functions vs. coroutine objects
Coroutine function: in Python, a function defined with async def is a coroutine function. Calling it does not execute it immediately; it returns a coroutine object.
Coroutine object: calling a coroutine function does not return the executed result, but a coroutine object. The object represents the coroutine to be executed; it has not started yet and must be scheduled onto an event loop to run.
A simple example illustrates the point:
'''
async def my_coroutine():
    # simulate an asynchronous operation
    await asyncio.sleep(1)
    return "Hello, asyncio!"

# Calling the coroutine function does not execute it; it returns a coroutine object
coro = my_coroutine()
# At this point, coro is a coroutine object, not the final result
print(coro)  # prints something like: <coroutine object my_coroutine at 0x10e1e4c10>
# To obtain the result, the coroutine object must be submitted to an event loop
result = asyncio.run(coro)
print(result)  # prints: Hello, asyncio!
'''
In this example, my_coroutine is a coroutine function. Calling it (my_coroutine()) returns a coroutine object rather than running the function body. To execute the coroutine and obtain its result, you must submit the coroutine object to an event loop, e.g. via asyncio.run() or a similar function.
This differs from a traditional synchronous function, which executes immediately when called and returns its result. The asynchronous nature of coroutines lets them suspend while waiting (e.g. on I/O), freeing the event loop to run other coroutines - the key to concurrency.
'''
1.2 The except block [skipped]
2.1 The try block
[Logic] Check whether a running event loop exists.
Its purpose shows up in the later logic (2.3).
2.2 The except block [skipped]
2.3 The if block
[Logic] Create a new event loop in a new thread, and run the coroutines there.
This is tightly coupled to the preceding check for a running event loop. The underlying principle is how coroutines are meant to be used:
Asynchronously submitting a coroutine to a running event loop for execution is perfectly fine, via asyncio.create_task() or loop.create_task(); that is standard async practice.
But synchronously running another coroutine from inside a coroutine (trying to restart the event loop, or block-waiting for a coroutine to finish) causes problems, because it violates how the event loop works.
So inside a coroutine you should always wait for other coroutines or tasks asynchronously (with await). If you need to run a coroutine on another thread, create a new event loop there.
If that explanation is hard to follow, try another angle: when the thread already has a running coroutine, the only way to execute a newly submitted coroutine right now is to open a new thread (with its own event loop).
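That escape hatch - a fresh event loop on a fresh thread - can be sketched as follows (run_in_fresh_loop is a hypothetical helper mirroring the pattern, not LangChain code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def task() -> str:
    await asyncio.sleep(0)
    return "done"

def run_in_fresh_loop(coro):
    """Run a coroutine on a brand-new event loop in a new thread - the
    same workaround handle_event applies when a loop is already running."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()

async def caller() -> str:
    # We are inside a running loop here; calling asyncio.run(task())
    # directly would raise RuntimeError, so delegate to another thread.
    return run_in_fresh_loop(task())

print(asyncio.run(caller()))  # done
```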
with ThreadPoolExecutor(1) as executor:
    executor.submit(
        cast(Callable, copy_context().run),  ### copy the current thread's context
        _run_coros,                          ### the function (the real executable body)
        coros,                               ### the function's argument (the coroutine list)
    ).result()
### how executor.submit()'s three arguments relate: copy_context().run(_run_coros, coros)
2.4 The else block
[Logic] The real execution logic: run the input list of coroutines. (Ignoring the Python-version compatibility branches, the real work is a single statement.)
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
# Run the coroutines in a new event loop, taking care to
# - install signal handlers
# - run pending tasks scheduled by `coros`
# - close asyncgens and executors
# - close the loop
with asyncio.Runner() as runner:
# Run the coroutine, get the result
for coro in coros:
try:
runner.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")
# Run pending tasks scheduled by coros until they are all done
while pending := asyncio.all_tasks(runner.get_loop()):
runner.run(asyncio.wait(pending))
else:
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
for coro in coros:
try:
asyncio.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")
ParentRunManager
class ParentRunManager(RunManager):
"""Sync Parent Run Manager."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
return manager
This defines the method for obtaining a child callback manager.
AsyncRunManager
class AsyncRunManager(BaseRunManager, ABC):
"""Async Run Manager."""
@abstractmethod
def get_sync(self) -> RunManager:
"""Get the equivalent sync RunManager.
Returns:
RunManager: The sync RunManager.
"""
async def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received.
Args:
text (str): The received text.
Returns:
Any: The result of the callback.
"""
await ahandle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
async def on_retry(
self,
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
await ahandle_event(
self.handlers,
"on_retry",
"ignore_retry",
retry_state,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
[**Core implementation 2**] ahandle_event (asynchronous event handling)
async def ahandle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for AsyncCallbackManager.
Note: This function is used by langserve to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
)
await asyncio.gather(
*(
_ahandle_event_for_handler(
handler,
event_name,
ignore_condition_name,
*args,
**kwargs,
)
for handler in handlers
if not handler.run_inline
)
)
The logic here is fairly clear - two steps:
•step 1: in Python, await suspends the current coroutine until the awaited coroutine finishes; here the for loop is the current coroutine and the awaited coroutine is the _ahandle_event_for_handler call. So all handlers meeting the condition execute _ahandle_event_for_handler serially, effectively synchronously. The condition's name conveys the same intent (run_inline - "inline"/"synchronous" execution).
•step 2: all handlers that do not require inline execution run _ahandle_event_for_handler concurrently.
In sum, ignoring the sync/async distinction, the logic simply iterates over the handler list and executes _ahandle_event_for_handler for each.
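The two-step shape can be reproduced in miniature (a sketch with hypothetical names, not LangChain code):

```python
import asyncio

log = []

class MiniHandler:
    def __init__(self, name: str, run_inline: bool) -> None:
        self.name = name
        self.run_inline = run_inline

    async def on_event(self) -> None:
        log.append(self.name)

async def dispatch(handlers) -> None:
    # step 1: inline handlers run one by one, awaited serially
    for h in [h for h in handlers if h.run_inline]:
        await h.on_event()
    # step 2: the remaining handlers run concurrently
    await asyncio.gather(
        *(h.on_event() for h in handlers if not h.run_inline)
    )

asyncio.run(dispatch(
    [MiniHandler("inline1", True), MiniHandler("bg1", False), MiniHandler("inline2", True)]
))
print(log)  # ['inline1', 'inline2', 'bg1']
```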
async def _ahandle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
if handler.run_inline:
event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None,
cast(
Callable,
functools.partial(
copy_context().run, event, *args, **kwargs
),
),
)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler(
handler,
"on_llm_start",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
logger.warning(
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
f" callback: {repr(e)}"
)
except Exception as e:
logger.warning(
f"Error in {handler.__class__.__name__}.{event_name} callback:"
f" {repr(e)}"
)
if handler.raise_error:
raise e
As before, we skip the except handling and focus on the happy path, which can be summarized in one sentence: execute the event method.
It first checks whether event is a coroutine function; if so, it is awaited directly. If not, handler.run_inline decides whether event is called synchronously or executed asynchronously (via an executor).
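That dispatch logic can be isolated into a small sketch (hypothetical names; asyncio.get_running_loop() is used here in place of the source's get_event_loop()):

```python
import asyncio
from functools import partial

calls = []

def sync_event(x):
    """A plain (non-coroutine) handler method."""
    calls.append(("sync", x))

async def async_event(x):
    """A coroutine handler method."""
    calls.append(("async", x))

async def invoke(event, x):
    # Mirror the dispatch: await coroutine functions directly,
    # push plain functions onto the default thread-pool executor.
    if asyncio.iscoroutinefunction(event):
        await event(x)
    else:
        await asyncio.get_running_loop().run_in_executor(None, partial(event, x))

async def main():
    await invoke(async_event, 1)
    await invoke(sync_event, 2)

asyncio.run(main())
print(calls)  # [('async', 1), ('sync', 2)]
```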
AsyncParentRunManager
class AsyncParentRunManager(AsyncRunManager):
"""Async Parent Run Manager."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
return manager
This defines the method for obtaining an asynchronous child callback manager.
Last Step. CallbackManager
After six steps across the code map, the rest, centered on CallbackManager, can be grouped into a single step. Although the remaining code is sizable, it introduces nothing new; reading it through shows that it merely provides the event entry points for the LLM abstractions.
Chain Group is another design point for productionizing LLM applications and is unrelated to this article's focus; to keep the article readable, we may boldly assume: a chain group is a collection of chains used to manage parallel or mutually independent chains. A chain group can handle multiple task flows that may share a starting or ending point, or need to synchronize at certain points; its design lets developers organize complex workflows into manageable units, executing multiple task flows at once for efficiency and more flexible resource management.
三. Closing Remarks
Callback capability is a very foundational supporting capability in a software framework, and it serves the framework rather than the framework's users. In other words, engineers building on LangChain cannot perceive callbacks directly, so choosing this capability to open a LangChain source-analysis series is not ideal. From another angle, as a foundational capability it will inevitably be depended upon in the later capability analyses, so it works well as a standalone side chapter.
The author's companion article on streaming (Stream) capability will reference this one, and that streaming analysis will serve as the true opening chapter of this series.