老铁们,今天我们来聊一下如何创建一个自定义的 LLM(Large Language Model)包装器,这样你就可以在 LangChain 中使用你自己的 LLM,或者是使用一个不同于 LangChain 默认支持的包装器。通过将你的 LLM 封装为标准 LLM 接口,你可以在现有的 LangChain 程序中使用自己的 LLM,而只需进行最小的代码修改。再一个好处就是,你的 LLM 会自动成为一个 LangChain Runnable,并因此受益于一些现成的优化功能,如异步支持、astream_events API 等。
实现
要实现一个自定义的 LLM,我们需要实现两个核心方法:
_call
: 接收一个字符串和一些可选的终止词,返回一个字符串。这个方法是通过invoke
调用的。_llm_type
: 返回一个字符串,仅用于日志记录。
此外还有一些可选的方法可供实现:
_identifying_params
: 返回一个用于识别模型和打印 LLM 的字典。是一个@property
。_acall
: 提供_call
的异步实现,通过ainvoke
调用。_stream
: 用于逐个令牌地流式输出。_astream
: 提供_stream
的异步实现;在新版 LangChain 中默认实现为_stream
。
示例代码
下面是一个简单的自定义 LLM 示例,它只返回输入的前 n 个字符:
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
class CustomLLM(LLM):
n: int
"""从提示的最后一条消息中回显的字符数。"""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return prompt[: self.n]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
for char in prompt[: self.n]:
chunk = GenerationChunk(text=char)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
@property
def _identifying_params(self) -> Dict[str, Any]:
return {"model_name": "CustomChatModel"}
@property
def _llm_type(self) -> str:
return "custom"
测试自定义 LLM
让我们来测试一下这个 LLM,它实现了 LangChain 的标准 Runnable 接口,很多 LangChain 抽象层都能支持它!
llm = CustomLLM(n=5)
print(llm)
llm.invoke("This is a foobar thing")
await llm.ainvoke("world")
llm.batch(["woof woof woof", "meow meow meow"])
await llm.abatch(["woof woof woof", "meow meow meow"])
async for token in llm.astream("hello"):
print(token, end="|", flush=True)
确认集成
我们确保其能与其他 LangChain API 友好集成:
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[("system", "you are a bot"), ("human", "{input}")]
)
llm = CustomLLM(n=7)
chain = prompt | llm
idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
print(event)
idx += 1
if idx > 7:
break
贡献提示
我们欢迎对聊天模型集成的贡献。以下是一些确保你的贡献被纳入 LangChain 的检查清单:
- 为所有初始化参数撰写文档字符串。
- 如果模型由服务提供支持,类的文档字符串中应包含 API 链接。
- 为重写的方法添加单元或集成测试。
今天的技术分享就到这里,希望对大家有帮助。开发过程中遇到问题也可以在评论区交流~
—END—