LangGraph Source Code Analysis | Structured Output

The with_structured_output method

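Main functionality

Lets the user convert the model's (LLM's) output into a specific data format:

  • dict
  • JSON Schema
  • TypedDict
  • Pydantic class: Pydantic is a Python library for data validation and parsing. BaseModel is the base class in Pydantic for defining data models (BaseModel offers a declarative way to define a data model, including field types, validation rules, and default values).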

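Core parameters

  • schema: the format of the model output. It can be:
    • an OpenAI function/tool schema
    • a JSON Schema
    • a TypedDict class (supported since version 0.2.26)
    • a Pydantic class: if a Pydantic class is given, the model output is automatically converted into an instance of that class and its fields are validated. For the other schema types the output is a plain dict and is not validated.
  • include_raw
    • Whether to also return the raw output. If True, a dict with the three keys "raw", "parsed", and "parsing_error" is returned.
    • If False, only the parsed structured output is returned (or a parsing error is raised).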
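All of the schema forms listed above end up as a tool definition. In the library this conversion is done by langchain_core.utils.function_calling.convert_to_openai_tool. To make the idea concrete, here is a minimal stdlib sketch that turns a flat TypedDict into an OpenAI-style tool dict.

This is not the library's converter. typeddict_to_tool and _JSON_TYPES are hypothetical, only primitive field types are handled, and the real converter covers many more cases.

```python
from typing import TypedDict, get_type_hints

# Hypothetical lookup table: only flat, primitive field types are handled.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def typeddict_to_tool(schema: type) -> dict:
    """Hypothetical helper: wrap a flat TypedDict as an OpenAI-style tool schema."""
    hints = get_type_hints(schema)
    properties = {name: {"type": _JSON_TYPES[tp]} for name, tp in hints.items()}
    return {
        "type": "function",
        "function": {
            # The class name becomes the tool name, the docstring its description.
            "name": schema.__name__,
            "description": (schema.__doc__ or "").strip(),
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": sorted(getattr(schema, "__required_keys__", hints)),
            },
        },
    }


class AnswerWithJustification(TypedDict):
    """An answer to the user question along with justification for the answer."""

    answer: str
    justification: str


tool = typeddict_to_tool(AnswerWithJustification)
print(tool["function"]["name"])
print(tool["function"]["parameters"]["properties"])
```

The tool name is what later lets the parser recognise "its" tool call among the model's tool calls.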

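Return value

Returns a Runnable. Given an input, it produces the model's response and parses it into data matching the specified schema:

  • When include_raw=False:
    • if schema is a Pydantic class, a Pydantic object is returned;
    • otherwise, the structured data is returned as a dict.
  • When include_raw=True:
    • a dict with the three keys "raw" (the BaseMessage returned by the model), "parsed" (None if parsing failed), and "parsing_error" is returned.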

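Core logic

First, the model must support tool use, because the output is formatted by binding the schema as a tool via bind_tools. If the model class does not implement bind_tools, a NotImplementedError is raised.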
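Whether a model "supports tool use" comes down to whether its class overrides bind_tools. The quoted source checks this with self.bind_tools is BaseChatModel.bind_tools.

In plain Python, accessing a method through an instance creates a new bound-method object. That comparison therefore never evaluates to True. For a model that does not override bind_tools, the error instead comes from the base class's bind_tools, which itself raises NotImplementedError when called on the next line.

The sketch below uses hypothetical class names. It shows the difference between an instance-level and a class-level check.

```python
class BaseChatModelSketch:
    """Hypothetical stand-in for BaseChatModel."""

    def bind_tools(self, tools, **kwargs):
        # Like the base class, the default implementation is not usable.
        raise NotImplementedError


class ToolCapableModel(BaseChatModelSketch):
    def bind_tools(self, tools, **kwargs):
        return f"bound {len(tools)} tool(s)"


class PlainModel(BaseChatModelSketch):
    pass


def supports_tool_use(model) -> bool:
    # Compare the function found on the class, not the bound method on the instance.
    return type(model).bind_tools is not BaseChatModelSketch.bind_tools


plain = PlainModel()
# Each instance attribute access builds a fresh bound method, so this is False.
print(plain.bind_tools is BaseChatModelSketch.bind_tools)
print(supports_tool_use(plain), supports_tool_use(ToolCapableModel()))
```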

Binding the tool and choosing the parser
llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and is_basemodel_subclass(schema):
    output_parser = PydanticToolsParser(
        tools=[cast(TypeBaseModel, schema)], first_tool_only=True
    )
else:
    key_name = convert_to_openai_tool(schema)["function"]["name"]
    output_parser = JsonOutputKeyToolsParser(
        key_name=key_name, first_tool_only=True
    )

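  • The output schema is bound to the LLM as a tool, and tool_choice="any" requires the model to call one of the bound tools (here, the only one).
  • If schema is a Pydantic class, PydanticToolsParser is used to parse the output.
  • For any other type, JsonOutputKeyToolsParser is used, keyed on the tool name that convert_to_openai_tool derives from the schema.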
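The parser choice above depends only on the kind of schema. The sketch below mimics it in plain Python, with a few assumptions:

- choose_parser and resolve_tool_name are hypothetical names.
- The hasattr(schema, "model_validate") test is a duck-typing stand-in for is_basemodel_subclass.
- resolve_tool_name is a simplified guess at how convert_to_openai_tool(schema)["function"]["name"] picks a name: the class name, an existing tool/function name, or a JSON Schema's title.

```python
from typing import Any, TypedDict


def resolve_tool_name(schema: Any) -> str:
    """Hypothetical, simplified stand-in for convert_to_openai_tool(schema)["function"]["name"]."""
    if isinstance(schema, type):
        return schema.__name__  # Pydantic or TypedDict class: use the class name
    if schema.get("type") == "function" and "function" in schema:
        return schema["function"]["name"]  # already an OpenAI tool schema
    if "name" in schema:
        return schema["name"]  # OpenAI function schema
    return schema["title"]  # JSON Schema: fall back to its title


def choose_parser(schema: Any) -> tuple:
    # Duck-typed stand-in for `isinstance(schema, type) and is_basemodel_subclass(schema)`.
    if isinstance(schema, type) and hasattr(schema, "model_validate"):
        return ("PydanticToolsParser", schema.__name__)
    return ("JsonOutputKeyToolsParser", resolve_tool_name(schema))


class Answer(TypedDict):
    answer: str


print(choose_parser(Answer))
print(choose_parser({"title": "Answer", "type": "object", "properties": {}}))
```

A TypedDict is a class but has no Pydantic validation. It therefore takes the JSON branch and yields a plain dict, matching the "not validated" note in the docstring.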
Parsing and structuring the output
if include_raw:
    parser_assign = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
    )
    parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
    parser_with_fallback = parser_assign.with_fallbacks(
        [parser_none], exception_key="parsing_error"
    )
    return RunnableMap(raw=llm) | parser_with_fallback
else:
    return llm | output_parser
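  • If include_raw is True, a more elaborate chain is used:
    • RunnableMap(raw=llm) keeps the raw message under "raw".
    • parser_assign adds "parsed" (and parsing_error=None).
    • If parsing raises, the fallback parser_none sets "parsed" to None, and with_fallbacks stores the exception under "parsing_error".
  • If include_raw is False, the chain is simply llm | output_parser, so any parsing error propagates to the caller.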
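The include_raw=True chain can be read as ordinary try/except logic. This plain-Python sketch uses the hypothetical names with_raw and fake_llm and no LangChain types. It reproduces the three keys the chain produces. The fallback branch only runs when the primary parser raises, and the exception is stored under "parsing_error" instead of propagating.

```python
import json
from typing import Any, Callable


def with_raw(llm: Callable[[str], Any], parser: Callable[[Any], Any]) -> Callable[[str], dict]:
    """Plain-Python sketch of the include_raw=True chain (not the Runnable API)."""

    def run(prompt: str) -> dict:
        # RunnableMap(raw=llm): keep the untouched model output under "raw".
        state = {"raw": llm(prompt)}
        try:
            # parser_assign: add "parsed" and an empty "parsing_error".
            return {**state, "parsed": parser(state["raw"]), "parsing_error": None}
        except Exception as exc:
            # with_fallbacks([parser_none], exception_key="parsing_error"):
            # record the exception and set "parsed" to None instead of raising.
            return {**state, "parsed": None, "parsing_error": exc}

    return run


def fake_llm(prompt: str) -> str:
    # Hypothetical model that returns its tool arguments as a JSON string.
    return '{"answer": "They weigh the same"}'


structured = with_raw(fake_llm, json.loads)
print(structured("What weighs more a pound of bricks or a pound of feathers"))
```

With include_raw=False there is no fallback step, so the same parsing error would reach the caller directly.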

Source code

    def with_structured_output(
        self,
        schema: Union[typing.Dict, type],  # noqa: UP006
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[typing.Dict, BaseModel]]:  # noqa: UP006
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema:
                The output schema. Can be passed in as:
                    - an OpenAI function/tool schema,
                    - a JSON Schema,
                    - a TypedDict class (support added in 0.2.26),
                    - or a Pydantic class.
                If ``schema`` is a Pydantic class then the model output will be a
                Pydantic instance of that class, and the model-generated fields will be
                validated by the Pydantic class. Otherwise the model output will be a
                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
                for more on how to properly specify types and descriptions of
                schema fields when specifying a Pydantic or TypedDict class.

                .. versionchanged:: 0.2.26

                        Added support for TypedDict class.

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

        Returns:
            A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

            If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
            an instance of ``schema`` (i.e., a Pydantic object).

            Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            If ``include_raw`` is True, then Runnable outputs a dict with keys:
                - ``"raw"``: BaseMessage
                - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
                - ``"parsing_error"``: Optional[BaseException]

        Example: Pydantic schema (include_raw=False):
            .. code-block:: python

                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")

                # -> AnswerWithJustification(
                #     answer='They weigh the same',
                #     justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
                # )

        Example: Pydantic schema (include_raw=True):
            .. code-block:: python

                from pydantic import BaseModel

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
                #     'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
                #     'parsing_error': None
                # }

        Example: Dict schema (include_raw=False):
            .. code-block:: python

                from pydantic import BaseModel
                from langchain_core.utils.function_calling import convert_to_openai_tool

                class AnswerWithJustification(BaseModel):
                    '''An answer to the user question along with justification for the answer.'''
                    answer: str
                    justification: str

                dict_schema = convert_to_openai_tool(AnswerWithJustification)
                llm = ChatModel(model="model-name", temperature=0)
                structured_llm = llm.with_structured_output(dict_schema)

                structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
                # -> {
                #     'answer': 'They weigh the same',
                #     'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
                # }
        """  # noqa: E501
        # Reject unsupported extra arguments
        if kwargs:
            msg = f"Received unsupported arguments {kwargs}"
            raise ValueError(msg)

        # Import the parser classes
        from langchain_core.output_parsers.openai_tools import (
            JsonOutputKeyToolsParser,
            PydanticToolsParser,
        )

        # Check whether the model supports with_structured_output (i.e. implements bind_tools)
        if self.bind_tools is BaseChatModel.bind_tools:
            msg = "with_structured_output is not implemented for this model."
            raise NotImplementedError(msg)
        # Bind the tool and choose the parser
        llm = self.bind_tools([schema], tool_choice="any")
        if isinstance(schema, type) and is_basemodel_subclass(schema):
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[cast(TypeBaseModel, schema)], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

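Parsers

JsonOutputKeyToolsParser

Its parse_result first calls the parent class's (JsonOutputToolsParser) parse_result.

Purpose: parse the result of an LLM call and convert it into a list of tool calls.

Parameters:

  • result: a list holding the results of the LLM call, usually Generation objects (this parser requires a ChatGeneration).
  • partial: an optional boolean, False by default, indicating whether to parse partial JSON. If True, the output is a JSON object containing the keys returned so far; if False, it is the complete JSON object.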
    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        # Take the first generation from the result list
        generation = result[0]
        # Make sure the generation is a ChatGeneration
        if not isinstance(generation, ChatGeneration):
            msg = "This output parser can only be used with a chat generation."
            raise OutputParserException(msg)
        # Extract the message from the generation
        message = generation.message
        if isinstance(message, AIMessage) and message.tool_calls:
            tool_calls = [dict(tc) for tc in message.tool_calls]
            for tool_call in tool_calls:
                if not self.return_id:
                    _ = tool_call.pop("id")
        else:
            try:
                raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
            except KeyError:
                return []
            tool_calls = parse_tool_calls(
                raw_tool_calls,
                partial=partial,
                strict=self.strict,
                return_id=self.return_id,
            )
        # for backwards compatibility
        for tc in tool_calls:
            tc["type"] = tc.pop("name")

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls
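For the branch where message.tool_calls is already populated, the normalisation done by parse_result boils down to a few dict operations. The sketch below is hypothetical (parse_tool_calls_sketch) and works on plain dicts shaped like LangChain tool calls. It does three things:

- drops the id unless return_id is set;
- moves the tool name into "type";
- applies first_tool_only.

It skips the additional_kwargs / parse_tool_calls fallback path.

```python
from typing import Any


def parse_tool_calls_sketch(
    tool_calls: list, *, return_id: bool = False, first_tool_only: bool = False
) -> Any:
    """Hypothetical sketch of the message.tool_calls branch of parse_result."""
    # Shallow copies, as in the source, so the original message dicts are not mutated.
    calls = [dict(tc) for tc in tool_calls]
    for tc in calls:
        if not return_id:
            tc.pop("id", None)
        # Backwards compatibility: expose the tool name under "type".
        tc["type"] = tc.pop("name")
    if first_tool_only:
        return calls[0] if calls else None
    return calls


message_tool_calls = [
    {"name": "AnswerWithJustification", "args": {"answer": "They weigh the same"},
     "id": "call_1", "type": "tool_call"},
]
print(parse_tool_calls_sketch(message_tool_calls, first_tool_only=True))
```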
Source code
class JsonOutputKeyToolsParser(JsonOutputToolsParser):
    """Parse tools from OpenAI response."""

    key_name: str
    """The type of tools to return."""

    def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
        """Parse the result of an LLM call to a list of tool calls.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                If False, the output will be the full JSON object.
                Default is False.

        Returns:
            The parsed tool calls.
        """
        parsed_result = super().parse_result(result, partial=partial)

        # first_tool_only: the parent returned only the first tool call; keep it only if its name matches key_name
        if self.first_tool_only:
            single_result = (
                parsed_result
                if parsed_result and parsed_result["type"] == self.key_name
                else None
            )
            if self.return_id:
                return single_result
            elif single_result:
                return single_result["args"]
            else:
                return None
        # Multiple tool calls: keep those whose name matches key_name
        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
        if not self.return_id:
            parsed_result = [res["args"] for res in parsed_result]
        return parsed_result
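JsonOutputKeyToolsParser then filters by key_name and unwraps "args". Here is a minimal sketch of that step on plain dicts; filter_by_key is a hypothetical name.

```python
from typing import Any


def filter_by_key(parsed: Any, key_name: str, *, first_tool_only: bool, return_id: bool = False) -> Any:
    """Hypothetical sketch of JsonOutputKeyToolsParser's filtering step."""
    if first_tool_only:
        # `parsed` is a single tool-call dict (or None) from the parent parser.
        single = parsed if parsed and parsed["type"] == key_name else None
        if return_id:
            return single
        return single["args"] if single else None
    # `parsed` is a list: keep only calls to the requested tool.
    matches = [res for res in parsed if res["type"] == key_name]
    return matches if return_id else [res["args"] for res in matches]


call = {"type": "AnswerWithJustification", "args": {"answer": "They weigh the same"}}
print(filter_by_key(call, "AnswerWithJustification", first_tool_only=True))
```

With first_tool_only=True only the first tool call is inspected. If its name differs from key_name, the result is None even when a later call matches. In with_structured_output this is harmless, since the schema is bound as the only tool.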

PydanticToolsParser

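Overview
  1. Input: the JSON data returned by the LLM (wrapped in a Generation object).
  2. Processing:
    • validate the JSON data and map it onto the Pydantic model;
    • on failure, raise a custom exception (OutputParserException).
  3. Output: a Pydantic model object, ready for further processing.
Note: the methods walked through below (_parse_obj, parse_result, get_format_instructions) come from PydanticOutputParser, a subclass of JsonOutputParser. PydanticToolsParser, which with_structured_output actually uses, extends JsonOutputToolsParser instead. It builds an instance of the Pydantic class from the args of the parsed tool call, but the validation idea is the same.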
_parse_obj: parsing JSON into a Pydantic model
def _parse_obj(self, obj: dict) -> TBaseModel:
    if PYDANTIC_MAJOR_VERSION == 2:
        try:
            if issubclass(self.pydantic_object, pydantic.BaseModel):
                return self.pydantic_object.model_validate(obj)
            elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                return self.pydantic_object.parse_obj(obj)
            else:
                msg = f"Unsupported model version for PydanticOutputParser: \
                        {self.pydantic_object.__class__}"
                raise OutputParserException(msg)
        except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
            raise self._parser_exception(e, obj) from e
    else:
        try:
            return self.pydantic_object.parse_obj(obj)
        except pydantic.ValidationError as e:
            raise self._parser_exception(e, obj) from e
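_parse_obj and _parser_exception together follow a simple pattern. They pick the validation entry point that matches the model's Pydantic version, and wrap any validation error in a parser exception that carries the offending JSON.

The stdlib sketch below imitates this without Pydantic. Answer, OutputParserError and parse_obj are hypothetical. Version detection uses hasattr (model_validate for v2, parse_obj for v1) rather than the issubclass checks in the real code.

```python
import json
from dataclasses import dataclass


class OutputParserError(Exception):
    """Hypothetical stand-in for OutputParserException; keeps the raw JSON."""

    def __init__(self, msg: str, llm_output: str = "") -> None:
        super().__init__(msg)
        self.llm_output = llm_output


@dataclass
class Answer:
    """Toy model exposing a Pydantic-v2-style classmethod."""

    answer: str

    @classmethod
    def model_validate(cls, obj: dict) -> "Answer":
        if not isinstance(obj.get("answer"), str):
            raise ValueError("field 'answer' must be a string")
        return cls(answer=obj["answer"])


def parse_obj(model: type, obj: dict):
    json_string = json.dumps(obj)
    if hasattr(model, "model_validate"):  # Pydantic v2 style entry point
        validate = model.model_validate
    elif hasattr(model, "parse_obj"):  # Pydantic v1 style entry point
        validate = model.parse_obj
    else:
        raise OutputParserError(f"Unsupported model class: {model!r}", json_string)
    try:
        return validate(obj)
    except ValueError as e:
        # Mirrors _parser_exception: name the model and attach the JSON string.
        msg = f"Failed to parse {model.__name__} from completion {json_string}. Got: {e}"
        raise OutputParserError(msg, llm_output=json_string) from e


print(parse_obj(Answer, {"answer": "They weigh the same"}))
```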

parse_result: parsing the data generated by the LLM
def parse_result(
    self, result: list[Generation], *, partial: bool = False
) -> Optional[TBaseModel]:
    """Parse the result of an LLM call to a pydantic object."""
    try:
        json_object = super().parse_result(result)
        return self._parse_obj(json_object)
    except OutputParserException as e:
        if partial:
            return None
        raise e

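  1. The parent class's parse_result() converts the Generation objects into JSON data (partial is not forwarded to it, so the JSON is always parsed in full).
  2. _parse_obj() then turns that JSON into a Pydantic model.
  3. If parsing fails and partial=True, None is returned; otherwise the exception is raised.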
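The partial flag in parse_result only changes what happens on failure. The sketch below is hypothetical: parse_result_sketch uses a required-keys check as a stand-in for Pydantic validation. As in the source, the JSON is always parsed in full. An OutputParserError becomes None only when partial=True, which suits streaming, where an incomplete object is expected.

```python
import json
from typing import Optional


class OutputParserError(Exception):
    """Hypothetical stand-in for OutputParserException."""


def parse_result_sketch(text: str, required: tuple, *, partial: bool = False) -> Optional[dict]:
    try:
        try:
            # Like the source, the JSON is parsed in full (partial is not forwarded).
            obj = json.loads(text)
        except json.JSONDecodeError as e:
            raise OutputParserError(f"Invalid JSON: {e}") from e
        if not isinstance(obj, dict):
            raise OutputParserError("Expected a JSON object")
        # Toy validation standing in for the Pydantic model.
        missing = [key for key in required if key not in obj]
        if missing:
            raise OutputParserError(f"Missing fields: {missing}")
        return obj
    except OutputParserError:
        if partial:
            # Incomplete output (e.g. while streaming) is not treated as an error.
            return None
        raise


print(parse_result_sketch('{"answer": "same"}', ("answer",)))
print(parse_result_sketch('{"answer": "sa', ("answer",), partial=True))
```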
get_format_instructions: returning the format instructions
def get_format_instructions(self) -> str:
    """Return the format instructions for the JSON output."""
    schema = dict(self.pydantic_object.model_json_schema().items())
    if "title" in schema:
        del schema["title"]
    if "type" in schema:
        del schema["type"]
    schema_str = json.dumps(schema, ensure_ascii=False)
    return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

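  • Get the Pydantic model's JSON schema and remove the unneeded top-level fields (title and type).
  • Embed the resulting JSON schema string in _PYDANTIC_FORMAT_INSTRUCTIONS and return it as the format instructions for the model.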
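get_format_instructions is mostly dictionary surgery followed by json.dumps. The sketch below uses a hand-written JSON schema and a hypothetical FORMAT_TEMPLATE, since the text of _PYDANTIC_FORMAT_INSTRUCTIONS is not shown here. It shows two things:

- only the top-level title and type are removed;
- ensure_ascii=False keeps non-ASCII characters readable instead of escaping them.

```python
import json

# Hypothetical template; the real _PYDANTIC_FORMAT_INSTRUCTIONS text differs.
FORMAT_TEMPLATE = "Return a JSON object matching this schema:\n{schema}"


def format_instructions(json_schema: dict) -> str:
    # Shallow copy, so the caller's top-level schema dict is left untouched.
    schema = dict(json_schema.items())
    # Drop the top-level keys; nested "title"/"type" entries are kept.
    schema.pop("title", None)
    schema.pop("type", None)
    # ensure_ascii=False keeps characters such as "é" instead of escaping them.
    return FORMAT_TEMPLATE.format(schema=json.dumps(schema, ensure_ascii=False))


answer_schema = {
    "title": "Answer",
    "type": "object",
    "properties": {"answer": {"title": "Answer", "type": "string", "description": "réponse"}},
    "required": ["answer"],
}
print(format_instructions(answer_schema))
```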
Source code
class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
    """Parse an output using a pydantic model."""

    pydantic_object: Annotated[type[TBaseModel], SkipValidation()]  # type: ignore
    """The pydantic model to parse."""

    def _parse_obj(self, obj: dict) -> TBaseModel:
        if PYDANTIC_MAJOR_VERSION == 2:
            try:
                if issubclass(self.pydantic_object, pydantic.BaseModel):
                    return self.pydantic_object.model_validate(obj)
                elif issubclass(self.pydantic_object, pydantic.v1.BaseModel):
                    return self.pydantic_object.parse_obj(obj)
                else:
                    msg = f"Unsupported model version for PydanticOutputParser: \
                            {self.pydantic_object.__class__}"
                    raise OutputParserException(msg)
            except (pydantic.ValidationError, pydantic.v1.ValidationError) as e:
                raise self._parser_exception(e, obj) from e
        else:  # pydantic v1
            try:
                return self.pydantic_object.parse_obj(obj)
            except pydantic.ValidationError as e:
                raise self._parser_exception(e, obj) from e

    def _parser_exception(
        self, e: Exception, json_object: dict
    ) -> OutputParserException:
        json_string = json.dumps(json_object)
        name = self.pydantic_object.__name__
        msg = f"Failed to parse {name} from completion {json_string}. Got: {e}"
        return OutputParserException(msg, llm_output=json_string)

    def parse_result(
        self, result: list[Generation], *, partial: bool = False
    ) -> Optional[TBaseModel]:
        """Parse the result of an LLM call to a pydantic object.

        Args:
            result: The result of the LLM call.
            partial: Whether to parse partial JSON objects.
                If True, the output will be a JSON object containing
                all the keys that have been returned so far.
                Defaults to False.

        Returns:
            The parsed pydantic object.
        """
        try:
            json_object = super().parse_result(result)
            return self._parse_obj(json_object)
        except OutputParserException as e:
            if partial:
                return None
            raise e

    def parse(self, text: str) -> TBaseModel:
        """Parse the output of an LLM call to a pydantic object.

        Args:
            text: The output of the LLM call.

        Returns:
            The parsed pydantic object.
        """
        return super().parse(text)

    def get_format_instructions(self) -> str:
        """Return the format instructions for the JSON output.

        Returns:
            The format instructions for the JSON output.
        """
        # Copy schema to avoid altering original Pydantic schema.
        schema = dict(self.pydantic_object.model_json_schema().items())

        # Remove extraneous fields.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema, ensure_ascii=False)

        return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @property
    def _type(self) -> str:
        return "pydantic"

    @property
    @override
    def OutputType(self) -> type[TBaseModel]:
        """Return the pydantic model."""
        return self.pydantic_object

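Summary

  • Structured output relies on the LLM having been fine-tuned for tool use.
  • With tool use, the LLM is given a list of available tools and, in a given situation, picks the most suitable one. It then produces a JSON request that follows that tool's calling specification (its parameter schema).
  • LangGraph's structured output (implemented in LangChain's BaseChatModel.with_structured_output) reuses this mechanism. It treats the given schema as a tool and asks the LLM to produce a "request" that matches the schema's description. It then strips the extra fields that tool use returns (such as the tool call id and name), keeping only the arguments.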