如何创建自定义 LLM 类

老铁们,今天我们来聊一下如何创建一个自定义的 LLM(Large Language Model)包装器,这样你就可以在 LangChain 中使用你自己的 LLM,或者是使用一个不同于 LangChain 默认支持的包装器。通过将你的 LLM 封装为标准 LLM 接口,你可以在现有的 LangChain 程序中使用自己的 LLM,而只需进行最小的代码修改。再一个好处就是,你的 LLM 会自动成为一个 LangChain Runnable,并因此受益于一些现成的优化功能,如异步支持、astream_events API 等。

实现

要实现一个自定义的 LLM,我们需要实现两个核心方法:

  • _call: 接收一个字符串和一些可选的终止词,返回一个字符串。这个方法是通过 invoke 调用的。
  • _llm_type: 返回一个字符串,仅用于日志记录。

此外还有一些可选的方法可供实现:

  • _identifying_params: 返回一个用于识别模型和打印 LLM 的字典。是一个 @property
  • _acall: 提供 _call 的异步实现,通过 ainvoke 调用。
  • _stream: 用于逐个令牌地流式输出。
  • _astream: 提供 _stream 的异步实现;在新版 LangChain 中默认实现为 _stream

示例代码

下面是一个简单的自定义 LLM 示例,它只返回输入的前 n 个字符:

from typing import Any, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

class CustomLLM(LLM):
    n: int
    """从提示的最后一条消息中回显的字符数。"""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[: self.n]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"model_name": "CustomChatModel"}

    @property
    def _llm_type(self) -> str:
        return "custom"

测试自定义 LLM

让我们来测试一下这个 LLM,它实现了 LangChain 的标准 Runnable 接口,很多 LangChain 抽象层都能支持它!

llm = CustomLLM(n=5)
print(llm)

llm.invoke("This is a foobar thing")

await llm.ainvoke("world")

llm.batch(["woof woof woof", "meow meow meow"])

await llm.abatch(["woof woof woof", "meow meow meow"])

async for token in llm.astream("hello"):
    print(token, end="|", flush=True)

确认集成

我们确保其能与其他 LangChain API 友好集成:

from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [("system", "you are a bot"), ("human", "{input}")]
)

llm = CustomLLM(n=7)
chain = prompt | llm

idx = 0
async for event in chain.astream_events({"input": "hello there!"}, version="v1"):
    print(event)
    idx += 1
    if idx > 7:
        break

贡献提示

我们欢迎对聊天模型集成的贡献。以下是一些确保你的贡献被纳入 LangChain 的检查清单:

  • 为所有初始化参数撰写文档字符串。
  • 如果模型由服务提供支持,类的文档字符串中应包含 API 链接。
  • 为重写的方法添加单元或集成测试。

今天的技术分享就到这里,希望对大家有帮助。开发过程中遇到问题也可以在评论区交流~

—END—

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值