A Look at Spring AI's StructuredOutputConverter

This article takes a look at Spring AI's StructuredOutputConverter.

StructuredOutputConverter

org/springframework/ai/converter/StructuredOutputConverter.java

public interface StructuredOutputConverter<T> extends Converter<String, T>, FormatProvider {

}

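The StructuredOutputConverter interface extends both Converter and FormatProvider. It has two abstract implementations, AbstractMessageOutputConverter and AbstractConversionServiceOutputConverter; BeanOutputConverter, covered below, implements the interface directly.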
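
To make the shape of this hierarchy concrete, here is a minimal, dependency-free sketch. Converter, FormatProvider and StructuredOutputConverter below are simplified local stand-ins for the Spring interfaces (no @Nullable, no andThen), and YesNoOutputConverter is a hypothetical converter invented for illustration. It only shows that writing a structured output converter means supplying a getFormat() instruction for the prompt and a convert() that parses the model's reply.

```java
import java.util.Locale;

// Simplified stand-in for org.springframework.core.convert.converter.Converter
interface Converter<S, T> {
    T convert(S source);
}

// Simplified stand-in for org.springframework.ai.converter.FormatProvider
interface FormatProvider {
    String getFormat();
}

// Mirrors the shape of StructuredOutputConverter: a String-to-T converter
// that can also describe the output format it expects from the model.
interface StructuredOutputConverter<T> extends Converter<String, T>, FormatProvider {
}

// Hypothetical converter: asks the model for "yes"/"no" and maps it to a Boolean.
class YesNoOutputConverter implements StructuredOutputConverter<Boolean> {

    @Override
    public String getFormat() {
        return "Respond with only the single word yes or no, without any other text.";
    }

    @Override
    public Boolean convert(String text) {
        // Tolerate surrounding whitespace and capitalisation in the model output
        String normalized = text.trim().toLowerCase(Locale.ROOT);
        if (normalized.equals("yes")) {
            return Boolean.TRUE;
        }
        if (normalized.equals("no")) {
            return Boolean.FALSE;
        }
        throw new IllegalArgumentException("Unexpected model output: " + text);
    }
}
```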

Converter

org/springframework/core/convert/converter/Converter.java

@FunctionalInterface
public interface Converter<S, T> {

	/**
	 * Convert the source object of type {@code S} to target type {@code T}.
	 * @param source the source object to convert, which must be an instance of {@code S} (never {@code null})
	 * @return the converted object, which must be an instance of {@code T} (potentially {@code null})
	 * @throws IllegalArgumentException if the source cannot be converted to the desired target type
	 */
	@Nullable
	T convert(S source);

	/**
	 * Construct a composed {@link Converter} that first applies this {@link Converter}
	 * to its input, and then applies the {@code after} {@link Converter} to the
	 * result.
	 * @param after the {@link Converter} to apply after this {@link Converter}
	 * is applied
	 * @param <U> the type of output of both the {@code after} {@link Converter}
	 * and the composed {@link Converter}
	 * @return a composed {@link Converter} that first applies this {@link Converter}
	 * and then applies the {@code after} {@link Converter}
	 * @since 5.3
	 */
	default <U> Converter<S, U> andThen(Converter<? super T, ? extends U> after) {
		Assert.notNull(after, "'after' Converter must not be null");
		return (S s) -> {
			T initialResult = convert(s);
			return (initialResult != null ? after.convert(initialResult) : null);
		};
	}

}

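The Converter interface declares the convert method and provides a default andThen method, which chains a second Converter onto the result of the first; if the first converter returns null, the composed converter returns null without calling the second one.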
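
The null short-circuit in andThen is easy to miss. The sketch below copies that default method into a local stand-in interface (an assumption made so the example runs without Spring on the classpath) and composes a hypothetical trimming converter with an integer parser.

```java
// Local stand-in that reproduces Spring's Converter#andThen semantics.
@FunctionalInterface
interface Converter<S, T> {
    T convert(S source);

    default <U> Converter<S, U> andThen(Converter<? super T, ? extends U> after) {
        java.util.Objects.requireNonNull(after, "'after' Converter must not be null");
        return (S s) -> {
            T initialResult = convert(s);
            // The second converter only runs when the first produced a value.
            return (initialResult != null ? after.convert(initialResult) : null);
        };
    }
}

class AndThenDemo {
    // Hypothetical first step: trim the text, mapping blank output to null.
    static final Converter<String, String> TRIM_OR_NULL =
            s -> s.isBlank() ? null : s.trim();

    // Hypothetical second step: parse the trimmed text as an Integer.
    static final Converter<String, Integer> PARSE_INT = Integer::valueOf;

    // Composed converter: trim first, then parse (skipped when trimming yields null).
    static final Converter<String, Integer> TRIM_THEN_PARSE = TRIM_OR_NULL.andThen(PARSE_INT);
}
```

With blank input, TRIM_OR_NULL returns null, so PARSE_INT never sees it and no NumberFormatException is thrown.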

FormatProvider

org/springframework/ai/converter/FormatProvider.java

public interface FormatProvider {

	/**
	 * Get the format of the output of a language generative.
	 * @return Returns a string containing instructions for how the output of a language
	 * generative should be formatted.
	 */
	String getFormat();

}

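FormatProvider declares a single method, getFormat, which returns instructions telling the model how its output should be formatted.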

AbstractMessageOutputConverter

org/springframework/ai/converter/AbstractMessageOutputConverter.java

public abstract class AbstractMessageOutputConverter<T> implements StructuredOutputConverter<T> {

	private MessageConverter messageConverter;

	/**
	 * Create a new AbstractMessageOutputConverter.
	 * @param messageConverter the message converter to use
	 */
	public AbstractMessageOutputConverter(MessageConverter messageConverter) {
		this.messageConverter = messageConverter;
	}

	/**
	 * Return the message converter used by this output converter.
	 * @return the message converter
	 */
	public MessageConverter getMessageConverter() {
		return this.messageConverter;
	}

}

AbstractMessageOutputConverter holds a MessageConverter; its concrete subclass is MapOutputConverter.

MapOutputConverter

org/springframework/ai/converter/MapOutputConverter.java

public class MapOutputConverter extends AbstractMessageOutputConverter<Map<String, Object>> {

	public MapOutputConverter() {
		super(new MappingJackson2MessageConverter());
	}

	@Override
	public Map<String, Object> convert(@NonNull String text) {
		if (text.startsWith("```json") && text.endsWith("```")) {
			text = text.substring(7, text.length() - 3);
		}

		Message<?> message = MessageBuilder.withPayload(text.getBytes(StandardCharsets.UTF_8)).build();
		return (Map) this.getMessageConverter().fromMessage(message, HashMap.class);
	}

	@Override
	public String getFormat() {
		String raw = """
				Your response should be in JSON format.
				The data structure for the JSON should match this Java class: %s
				Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
				Remove the ```json markdown surrounding the output including the trailing "```".
				""";
		return String.format(raw, HashMap.class.getName());
	}

}

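MapOutputConverter extends AbstractMessageOutputConverter and uses a MappingJackson2MessageConverter as its MessageConverter. Its convert method strips a surrounding ```json ... ``` fence, wraps the text's UTF-8 bytes in a Message and deserializes it into a HashMap; its getFormat asks the model for RFC8259-compliant JSON.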
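
The fence handling in MapOutputConverter.convert is stricter than it looks: it only strips the wrapper when the text starts with exactly "```json" and ends with "```", and it does not trim first. The snippet below copies just that string logic into a plain helper (the class and method names are made up for illustration) so the behaviour can be checked without Jackson.

```java
class MapFenceStripping {
    // Same condition and substring bounds as MapOutputConverter#convert.
    static String strip(String text) {
        if (text.startsWith("```json") && text.endsWith("```")) {
            // 7 = length of "```json", 3 = length of the closing "```"
            text = text.substring(7, text.length() - 3);
        }
        return text;
    }
}
```

The leftover newlines are harmless because Jackson skips whitespace, but an upper-case "```JSON" tag or leading whitespace before the fence leaves the backticks in place, and I would expect Jackson to reject that text.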

AbstractConversionServiceOutputConverter

org/springframework/ai/converter/AbstractConversionServiceOutputConverter.java

public abstract class AbstractConversionServiceOutputConverter<T> implements StructuredOutputConverter<T> {

	private final DefaultConversionService conversionService;

	/**
	 * Create a new {@link AbstractConversionServiceOutputConverter} instance.
	 * @param conversionService the {@link DefaultConversionService} to use for converting
	 * the output.
	 */
	public AbstractConversionServiceOutputConverter(DefaultConversionService conversionService) {
		this.conversionService = conversionService;
	}

	/**
	 * Return the ConversionService used by this converter.
	 * @return the ConversionService used by this converter.
	 */
	public DefaultConversionService getConversionService() {
		return this.conversionService;
	}

}

AbstractConversionServiceOutputConverter holds a DefaultConversionService; its concrete subclass is ListOutputConverter.

ListOutputConverter

org/springframework/ai/converter/ListOutputConverter.java

public class ListOutputConverter extends AbstractConversionServiceOutputConverter<List<String>> {

	public ListOutputConverter(DefaultConversionService defaultConversionService) {
		super(defaultConversionService);
	}

	@Override
	public String getFormat() {
		return """
				Respond with only a list of comma-separated values, without any leading or trailing text.
				Example format: foo, bar, baz
				""";
	}

	@Override
	public List<String> convert(@NonNull String text) {
		return this.getConversionService().convert(text, List.class);
	}

}

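ListOutputConverter extends AbstractConversionServiceOutputConverter. Its getFormat asks for comma-separated values, and its convert method uses the DefaultConversionService to turn that text into a List<String>.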
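
ListOutputConverter leaves the actual parsing to the String-to-Collection support registered in DefaultConversionService. As far as I understand Spring's StringToCollectionConverter, when no element type is known it splits the string on commas and trims each element. The helper below is a stdlib-only approximation of that behaviour written for illustration, not Spring's code, and it ignores edge cases such as empty input.

```java
import java.util.ArrayList;
import java.util.List;

class CommaListSketch {
    // Approximates converting "foo, bar, baz" into a List<String>:
    // split on commas, then trim surrounding whitespace from each element.
    static List<String> toList(String text) {
        List<String> result = new ArrayList<>();
        for (String field : text.split(",")) {
            result.add(field.trim());
        }
        return result;
    }
}
```

This also shows why the format instruction insists on "only a list of comma-separated values": with this splitting, a leading sentence such as "Here are the colors:" would simply become part of the first element.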

BeanOutputConverter

org/springframework/ai/converter/BeanOutputConverter.java

public class BeanOutputConverter<T> implements StructuredOutputConverter<T> {

	private final Logger logger = LoggerFactory.getLogger(BeanOutputConverter.class);

	/**
	 * The target class type reference to which the output will be converted.
	 */
	private final Type type;

	/** The object mapper used for deserialization and other JSON operations. */
	private final ObjectMapper objectMapper;

	/** Holds the generated JSON schema for the target type. */
	private String jsonSchema;

	/**
	 * Constructor to initialize with the target type's class.
	 * @param clazz The target type's class.
	 */
	public BeanOutputConverter(Class<T> clazz) {
		this(ParameterizedTypeReference.forType(clazz));
	}

	/**
	 * Constructor to initialize with the target type's class, a custom object mapper, and
	 * a line endings normalizer to ensure consistent line endings on any platform.
	 * @param clazz The target type's class.
	 * @param objectMapper Custom object mapper for JSON operations. endings.
	 */
	public BeanOutputConverter(Class<T> clazz, ObjectMapper objectMapper) {
		this(ParameterizedTypeReference.forType(clazz), objectMapper);
	}

	/**
	 * Constructor to initialize with the target class type reference.
	 * @param typeRef The target class type reference.
	 */
	public BeanOutputConverter(ParameterizedTypeReference<T> typeRef) {
		this(typeRef.getType(), null);
	}

	/**
	 * Constructor to initialize with the target class type reference, a custom object
	 * mapper, and a line endings normalizer to ensure consistent line endings on any
	 * platform.
	 * @param typeRef The target class type reference.
	 * @param objectMapper Custom object mapper for JSON operations. endings.
	 */
	public BeanOutputConverter(ParameterizedTypeReference<T> typeRef, ObjectMapper objectMapper) {
		this(typeRef.getType(), objectMapper);
	}

	/**
	 * Constructor to initialize with the target class type reference, a custom object
	 * mapper, and a line endings normalizer to ensure consistent line endings on any
	 * platform.
	 * @param type The target class type.
	 * @param objectMapper Custom object mapper for JSON operations. endings.
	 */
	private BeanOutputConverter(Type type, ObjectMapper objectMapper) {
		Objects.requireNonNull(type, "Type cannot be null;");
		this.type = type;
		this.objectMapper = objectMapper != null ? objectMapper : getObjectMapper();
		generateSchema();
	}

	/**
	 * Generates the JSON schema for the target type.
	 */
	private void generateSchema() {
		JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
				JacksonOption.RESPECT_JSONPROPERTY_ORDER);
		SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(
				com.github.victools.jsonschema.generator.SchemaVersion.DRAFT_2020_12,
				com.github.victools.jsonschema.generator.OptionPreset.PLAIN_JSON)
			.with(jacksonModule)
			.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT);
		SchemaGeneratorConfig config = configBuilder.build();
		SchemaGenerator generator = new SchemaGenerator(config);
		JsonNode jsonNode = generator.generateSchema(this.type);
		ObjectWriter objectWriter = this.objectMapper.writer(new DefaultPrettyPrinter()
			.withObjectIndenter(new DefaultIndenter().withLinefeed(System.lineSeparator())));
		try {
			this.jsonSchema = objectWriter.writeValueAsString(jsonNode);
		}
		catch (JsonProcessingException e) {
			logger.error("Could not pretty print json schema for jsonNode: {}", jsonNode);
			throw new RuntimeException("Could not pretty print json schema for " + this.type, e);
		}
	}

	/**
	 * Parses the given text to transform it to the desired target type.
	 * @param text The LLM output in string format.
	 * @return The parsed output in the desired target type.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public T convert(@NonNull String text) {
		try {
			// Remove leading and trailing whitespace
			text = text.trim();

			// Check for and remove triple backticks and "json" identifier
			if (text.startsWith("```") && text.endsWith("```")) {
				// Remove the first line if it contains "```json"
				String[] lines = text.split("\n", 2);
				if (lines[0].trim().equalsIgnoreCase("```json")) {
					text = lines.length > 1 ? lines[1] : "";
				}
				else {
					text = text.substring(3); // Remove leading ```
				}

				// Remove trailing ```
				text = text.substring(0, text.length() - 3);

				// Trim again to remove any potential whitespace
				text = text.trim();
			}
			return (T) this.objectMapper.readValue(text, this.objectMapper.constructType(this.type));
		}
		catch (JsonProcessingException e) {
			logger.error(SENSITIVE_DATA_MARKER,
					"Could not parse the given text to the desired target type: \"{}\" into {}", text, this.type);
			throw new RuntimeException(e);
		}
	}

	/**
	 * Configures and returns an object mapper for JSON operations.
	 * @return Configured object mapper.
	 */
	protected ObjectMapper getObjectMapper() {
		return JsonMapper.builder()
			.addModules(JacksonUtils.instantiateAvailableModules())
			.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
			.build();
	}

	/**
	 * Provides the expected format of the response, instructing that it should adhere to
	 * the generated JSON schema.
	 * @return The instruction format string.
	 */
	@Override
	public String getFormat() {
		String template = """
				Your response should be in JSON format.
				Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
				Do not include markdown code blocks in your response.
				Remove the ```json markdown from the output.
				Here is the JSON Schema instance your output must adhere to:
				```%s```
				""";
		return String.format(template, this.jsonSchema);
	}

	/**
	 * Provides the generated JSON schema for the target type.
	 * @return The generated JSON schema.
	 */
	public String getJsonSchema() {
		return this.jsonSchema;
	}

	public Map<String, Object> getJsonSchemaMap() {
		try {
			return this.objectMapper.readValue(this.jsonSchema, Map.class);
		}
		catch (JsonProcessingException ex) {
			logger.error("Could not parse the JSON Schema to a Map object", ex);
			throw new IllegalStateException(ex);
		}
	}

}

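BeanOutputConverter generates a JSON Schema for the target type when it is constructed (using the victools jsonschema-generator with its Jackson module) and embeds that schema in the getFormat instructions. Its convert method trims the text, strips an optional markdown code fence, and then uses the ObjectMapper to deserialize the JSON into the target bean.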
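
BeanOutputConverter's pre-processing is more forgiving than MapOutputConverter's: it trims first, and it matches the opening "```json" line case-insensitively. The helper below copies only that string handling from convert so it can be exercised without Jackson; the class and method names are hypothetical.

```java
class BeanFenceStripping {
    // Mirrors the pre-processing in BeanOutputConverter#convert before Jackson runs.
    static String strip(String text) {
        text = text.trim();
        if (text.startsWith("```") && text.endsWith("```")) {
            String[] lines = text.split("\n", 2);
            if (lines[0].trim().equalsIgnoreCase("```json")) {
                // Drop the whole opening line, e.g. "```json" or "```JSON"
                text = lines.length > 1 ? lines[1] : "";
            }
            else {
                // Any other opening fence: drop only the three backticks
                text = text.substring(3);
            }
            // Drop the closing fence, then any whitespace left around the JSON
            text = text.substring(0, text.length() - 3);
            text = text.trim();
        }
        return text;
    }
}
```

Only a "json" tag is recognised: with another tag such as "```javascript", just the backticks are removed and the word javascript stays in front of the JSON, which Jackson would reject. Also, as in the original code, a reply consisting of nothing but "```" makes the second substring call throw StringIndexOutOfBoundsException, which the JsonProcessingException handler does not catch.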

Examples

chatModel + outputConverter

	@Test
	void mapOutputConvert() {
		MapOutputConverter outputConverter = new MapOutputConverter();

		String format = outputConverter.getFormat();
		String template = """
				For each letter in the RGB color scheme, tell me what it stands for.
				Example: R -> Red.
				{format}
				""";
		PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
		Prompt prompt = new Prompt(promptTemplate.createMessage());

		Generation generation = this.chatModel.call(prompt).getResult();

		Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
		assertThat(result).isNotNull();
		assertThat((String) result.get("R")).containsIgnoringCase("red");
		assertThat((String) result.get("G")).containsIgnoringCase("green");
		assertThat((String) result.get("B")).containsIgnoringCase("blue");
	}

chatClient + outputConverter

	@Test
	public void responseEntityTest() {

		ChatResponseMetadata metadata = ChatResponseMetadata.builder().keyValue("key1", "value1").build();

		var chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
				{"name":"John", "age":30}
				"""))), metadata);

		given(this.chatModel.call(this.promptCaptor.capture())).willReturn(chatResponse);

		ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(this.chatModel)
			.build()
			.prompt()
			.user("Tell me about John")
			.call()
			.responseEntity(MyBean.class);

		assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
		assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1");

		assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));

		Message userMessage = this.promptCaptor.getValue().getInstructions().get(0);
		assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
		assertThat(userMessage.getText()).contains("Tell me about John");
	}

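Here the ChatClient call specifies the target type with responseEntity(MyBean.class). Internally it uses a BeanOutputConverter, and the returned ResponseEntity exposes both the raw ChatResponse and the converted MyBean.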
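
Conceptually, this wraps the model call with the two halves of the converter contract: the converter's getFormat() text goes into the prompt, and convert() is applied to the reply. The sketch below imitates that flow with a stubbed model, much like the mocked chatModel in the test. The StructuredOutputConverter stand-in, the callForEntity helper, the integer converter and the exact way the format is appended are all assumptions for illustration, not how DefaultChatClient is written.

```java
import java.util.function.Function;

// Simplified stand-in for Spring AI's StructuredOutputConverter<T>
// (format instructions plus String-to-T conversion).
interface StructuredOutputConverter<T> {
    String getFormat();

    T convert(String text);
}

class EntityFlowSketch {

    // Hypothetical helper: append the format instructions to the user text,
    // send the prompt to the model, then convert the model's reply.
    static <T> T callForEntity(Function<String, String> model, String userText,
            StructuredOutputConverter<T> converter) {
        String prompt = userText + "\n" + converter.getFormat();
        String reply = model.apply(prompt);
        return converter.convert(reply);
    }

    // Hypothetical converter that expects a bare integer from the model.
    static final StructuredOutputConverter<Integer> INTEGER_CONVERTER = new StructuredOutputConverter<>() {
        @Override
        public String getFormat() {
            return "Respond with only an integer, without any other text.";
        }

        @Override
        public Integer convert(String text) {
            return Integer.valueOf(text.trim());
        }
    };
}
```

With a real ChatClient you would not write such a helper yourself; the point is only that the prompt carries the converter's instructions and the reply passes through its convert method.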

Summary

Spring AI provides Structured Output Converters for turning LLM output into structured formats. The main implementations at present are MapOutputConverter, ListOutputConverter and BeanOutputConverter.

