A Look at RAG in Spring AI

This article looks at the RAG (Retrieval-Augmented Generation) support in Spring AI.

Sequential RAG Flows

Naive RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .queryAugmenter(ContextualQueryAugmenter.builder()
                .allowEmptyContext(true)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();

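By default, RetrievalAugmentationAdvisor does not allow the retrieved context to be empty: when nothing is retrieved, it instructs the model not to answer the user query. Setting allowEmptyContext to true lets the model answer even when the retrieved context is empty.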
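To make the effect of the flag concrete, here is a minimal plain-Java sketch of the decision an augmenter makes when the retrieved context is empty. The class name, method name, and prompt wording are hypothetical and chosen for illustration; this is not the code of Spring AI's ContextualQueryAugmenter.

```java
import java.util.List;

// Hypothetical sketch; this is not Spring AI's ContextualQueryAugmenter.
final class EmptyContextSketch {

    // Returns the text that would be sent to the model for a query and its retrieved snippets.
    static String augment(String query, List<String> context, boolean allowEmptyContext) {
        if (context.isEmpty()) {
            // allowEmptyContext = true: pass the query through so the model may still answer.
            // allowEmptyContext = false: tell the model to decline instead of guessing.
            return allowEmptyContext
                    ? query
                    : "The question is outside the knowledge base. Politely say that you cannot answer it.";
        }
        // Non-empty context: prepend the snippets and restrict the answer to them.
        return "Context:\n" + String.join("\n", context)
                + "\n\nAnswer the question using only the context above.\nQuestion: " + query;
    }
}
```

With an empty list and the flag set to true, the query is returned unchanged; with the flag set to false, the returned text is the refusal instruction.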

Advanced RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .queryTransformers(RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build())
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();

Advanced RAG can additionally configure queryTransformers to transform the query before retrieval.

Modular RAG

Inspired by the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks", Spring AI implements a Modular RAG architecture, which is divided into the following stages: Pre-Retrieval, Retrieval, Post-Retrieval, and Generation.

Pre-Retrieval

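This stage enhances and transforms the user input so that retrieval works more effectively, addressing poorly formed queries, ambiguous query semantics, unsupported languages, and similar problems.

1. QueryAugmenter (query augmentation)

Augments the user query with additional contextual data, giving the model the context it needs to answer the question.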

  • ContextualQueryAugmenter augments the query with the retrieved context
QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
        .allowEmptyContext(false)
        .build();
Query augmentedQuery = augmenter.augment(query, documents);

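2. QueryTransformer (query transformation)

User input is often incomplete and carries little key information, which makes it hard for the model to understand and answer. The query is therefore rewritten, either through prompt tuning or by having a large model rewrite it.
When using a QueryTransformer, a low temperature (for example 0.0) is recommended so that the results are more deterministic and accurate.
There are three implementations: CompressionQueryTransformer, RewriteQueryTransformer, and TranslationQueryTransformer.

  • CompressionQueryTransformer uses a large model to compress the conversation history and the follow-up query into a standalone query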
Query query = Query.builder()
        .text("And what is its second largest city?")
        .history(new UserMessage("What is the capital of Denmark?"),
                new AssistantMessage("Copenhagen is the capital of Denmark."))
        .build();

QueryTransformer queryTransformer = CompressionQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .build();

Query transformedQuery = queryTransformer.transform(query);
  • RewriteQueryTransformer uses a large model to rewrite the query
Query query = new Query("I'm studying machine learning. What is an LLM?");

QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .build();

Query transformedQuery = queryTransformer.transform(query);
  • TranslationQueryTransformer uses a large model to translate the query into the target language
Query query = new Query("Hvad er Danmarks hovedstad?");

QueryTransformer queryTransformer = TranslationQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .targetLanguage("english")
        .build();

Query transformedQuery = queryTransformer.transform(query);

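3. QueryExpander (query expansion)

Expands the user query into multiple semantically different variants that capture different perspectives, which helps retrieve additional context and increases the chance of finding relevant results.

  • MultiQueryExpander uses a large model to expand the query
MultiQueryExpander queryExpander = MultiQueryExpander.builder()
    .chatClientBuilder(chatClientBuilder)
    .numberOfQueries(3)
    .includeOriginal(false) // the original query is included by default; false excludes it
    .build();
List<Query> queries = queryExpander.expand(new Query("How to run a Spring Boot app?"));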
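The interaction between numberOfQueries and includeOriginal can be sketched in plain Java. This is a simplified illustration under stated assumptions: the class, the method signature, and the variantGenerator parameter (a stand-in for the LLM call) are all hypothetical, and placing the original query first is an assumption rather than documented behaviour.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical sketch of query expansion; not Spring AI's MultiQueryExpander.
final class QueryExpansionSketch {

    // variantGenerator stands in for the LLM call: (query, n) -> n alternative phrasings.
    static List<String> expand(String query, int numberOfQueries, boolean includeOriginal,
            BiFunction<String, Integer, List<String>> variantGenerator) {
        List<String> result = new ArrayList<>();
        if (includeOriginal) {
            // Assumption: the original query comes first when it is kept.
            result.add(query);
        }
        result.addAll(variantGenerator.apply(query, numberOfQueries));
        return result;
    }
}
```

With numberOfQueries set to 3, the sketch returns three queries when includeOriginal is false and four when it is true.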

Retrieval

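Responsible for querying data systems such as a vector store and retrieving the Documents most relevant to the user query.

1. DocumentRetriever (retriever)

Retrieves Documents for the queries produced by the QueryExpander from different data sources, such as a search engine, vector store, database, or knowledge graph. The main implementations are VectorStoreDocumentRetriever and WebSearchRetriever.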

  • VectorStoreDocumentRetriever
DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
    .vectorStore(vectorStore)
    .similarityThreshold(0.73)
    .topK(5)
    .filterExpression(new FilterExpressionBuilder()
        .eq("genre", "fairytale")
        .build())
    .build();
List<Document> documents = retriever.retrieve(new Query("What is the main character of the story?"));
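The similarityThreshold and topK settings can be pictured as a filter followed by a cut-off. The sketch below is a brute-force, in-memory illustration using cosine similarity; the class and its types are hypothetical, and a real vector store would normally use an index and may use a different distance metric.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical in-memory sketch of threshold + top-K retrieval; not a Spring AI class.
final class SimilaritySearchSketch {

    static final class Scored {
        final String id;
        final double score;
        Scored(String id, double score) { this.id = id; this.score = score; }
    }

    // Cosine similarity between two vectors of equal length.
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Keeps documents whose similarity is at least the threshold, best first, at most topK.
    static List<Scored> search(double[] query, List<String> ids, List<double[]> vectors,
            double similarityThreshold, int topK) {
        List<Scored> hits = new ArrayList<>();
        for (int i = 0; i < ids.size(); i++) {
            double score = cosine(query, vectors.get(i));
            if (score >= similarityThreshold) {
                hits.add(new Scored(ids.get(i), score));
            }
        }
        hits.sort((x, y) -> Double.compare(y.score, x.score));
        return new ArrayList<>(hits.subList(0, Math.min(topK, hits.size())));
    }
}
```

Raising the threshold trades recall for precision, while topK caps how much context reaches the prompt regardless of how many documents pass the threshold.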

2. DocumentJoiner

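Combines the Documents retrieved for multiple queries and from multiple data sources into a single collection of Documents. The available implementation is ConcatenationDocumentJoiner.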

  • ConcatenationDocumentJoiner
Map<Query, List<List<Document>>> documentsForQuery = ...
DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner();
List<Document> documents = documentJoiner.join(documentsForQuery);
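The shape of the input, a map from each query to one list of documents per data source, is easier to see in a small sketch. The following plain-Java version flattens that structure and keeps the first occurrence of each document id. The Doc type and the choice of de-duplicating by id are assumptions made for this illustration, not a copy of ConcatenationDocumentJoiner.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of a concatenating joiner; not Spring AI's implementation.
final class ConcatenationJoinSketch {

    static final class Doc {
        final String id;
        final String text;
        Doc(String id, String text) { this.id = id; this.text = text; }
    }

    // Input: query -> (one list of documents per data source). Output: one flat list.
    static List<Doc> join(Map<String, List<List<Doc>>> documentsForQuery) {
        // LinkedHashMap keeps encounter order; putIfAbsent keeps the first duplicate.
        Map<String, Doc> byId = new LinkedHashMap<>();
        for (List<List<Doc>> perSource : documentsForQuery.values()) {
            for (List<Doc> docs : perSource) {
                for (Doc doc : docs) {
                    byId.putIfAbsent(doc.id, doc);
                }
            }
        }
        return new ArrayList<>(byId.values());
    }
}
```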

Post-Retrieval

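Responsible for processing the retrieved Documents to get the best possible output, addressing problems such as the lost-in-the-middle effect and the model's context length limit.

  1. DocumentRanker: orders and ranks Documents by their relevance to the user query;
  2. DocumentSelector: removes irrelevant or redundant Documents from the retrieved list;
  3. DocumentCompressor: compresses each Document to reduce noise and redundancy in the retrieved information.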
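The three post-retrieval roles can be illustrated with deliberately simple stand-ins: sorting by score for ranking, a score cut-off for selection, and truncation for compression. Everything in this sketch (the class, the Doc type, and the strategies themselves) is hypothetical; real implementations would typically use a re-ranking model or an LLM-based summarizer.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for the ranker, selector, and compressor roles.
final class PostRetrievalSketch {

    static final class Doc {
        final String text;
        final double score;
        Doc(String text, double score) { this.text = text; this.score = score; }
    }

    // Ranking: order by relevance score, highest first.
    static List<Doc> rank(List<Doc> docs) {
        List<Doc> sorted = new ArrayList<>(docs);
        sorted.sort((a, b) -> Double.compare(b.score, a.score));
        return sorted;
    }

    // Selection: drop documents whose score is below minScore.
    static List<Doc> select(List<Doc> docs, double minScore) {
        List<Doc> kept = new ArrayList<>();
        for (Doc doc : docs) {
            if (doc.score >= minScore) {
                kept.add(doc);
            }
        }
        return kept;
    }

    // Compression: truncate each document to at most maxChars characters.
    static List<Doc> compress(List<Doc> docs, int maxChars) {
        List<Doc> shortened = new ArrayList<>();
        for (Doc doc : docs) {
            String text = doc.text.length() <= maxChars ? doc.text : doc.text.substring(0, maxChars);
            shortened.add(new Doc(text, doc.score));
        }
        return shortened;
    }
}
```

Chaining the three as compress(select(rank(docs), minScore), maxChars) yields a short, ordered list in which the most relevant material comes first.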

Generation

Generates the model output for the user query.

Source code

org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

	public static final class Builder {

		private List<QueryTransformer> queryTransformers;

		private QueryExpander queryExpander;

		private DocumentRetriever documentRetriever;

		private DocumentJoiner documentJoiner;

		private QueryAugmenter queryAugmenter;

		private TaskExecutor taskExecutor;

		private Scheduler scheduler;

		private Integer order;

		private Builder() {
		}

		//......
	}	

The Builder of RetrievalAugmentationAdvisor exposes configuration for the Pre-Retrieval components (queryAugmenter, queryTransformers, queryExpander) and the Retrieval components (documentRetriever, documentJoiner).

Examples
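How these builder components fit together can be sketched as a plain-Java pipeline in which each component is modelled as a function over strings. This is a conceptual sketch only: the class, the method, and the use of String in place of Query and Document are assumptions, and passing the original (untransformed) query to the augmenter reflects my reading of the advisor rather than a documented guarantee.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical sketch of the flow; not the code of RetrievalAugmentationAdvisor.
final class ModularRagFlowSketch {

    static String run(String userQuery,
            List<UnaryOperator<String>> queryTransformers,
            Function<String, List<String>> queryExpander,
            Function<String, List<String>> documentRetriever,
            Function<Map<String, List<String>>, List<String>> documentJoiner,
            BiFunction<String, List<String>, String> queryAugmenter) {
        // 1. Apply the transformers in order; each sees the previous output.
        String transformed = userQuery;
        for (UnaryOperator<String> transformer : queryTransformers) {
            transformed = transformer.apply(transformed);
        }
        // 2. Expand the transformed query into one or more queries.
        List<String> expanded = queryExpander.apply(transformed);
        // 3. Retrieve documents for every expanded query.
        Map<String, List<String>> documentsForQuery = new LinkedHashMap<>();
        for (String query : expanded) {
            documentsForQuery.put(query, documentRetriever.apply(query));
        }
        // 4. Join the per-query results into a single list.
        List<String> documents = documentJoiner.apply(documentsForQuery);
        // 5. Augment the original user query with the joined documents (assumption).
        return queryAugmenter.apply(userQuery, documents);
    }
}
```

The returned string plays the role of the augmented prompt that is finally handed to the model in the Generation stage.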

ModuleRAGBasicController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGBasicController {

	private final ChatClient chatClient;
	private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

	public ModuleRAGBasicController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {

		this.chatClient = chatClientBuilder.build();
		this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
				.documentRetriever(VectorStoreDocumentRetriever.builder()
						.similarityThreshold(0.50)
						.vectorStore(vectorStore)
						.build()
				).build();
	}

	@GetMapping("/rag/basic")
	public String chatWithDocument(@RequestParam("prompt") String prompt) {

		return chatClient.prompt()
				.advisors(retrievalAugmentationAdvisor)
				.user(prompt)
				.call()
				.content();
	}

}

ModuleRAGCompressionController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGCompressionController {

	private final ChatClient chatClient;

	private final MessageChatMemoryAdvisor chatMemoryAdvisor;

	private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

	public ModuleRAGCompressionController(
			ChatClient.Builder chatClientBuilder,
			ChatMemory chatMemory,
			VectorStore vectorStore) {

		this.chatClient = chatClientBuilder.build();

		this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
				.build();

		var documentRetriever = VectorStoreDocumentRetriever.builder()
				.vectorStore(vectorStore)
				.similarityThreshold(0.50)
				.build();

		var queryTransformer = CompressionQueryTransformer.builder()
				.chatClientBuilder(chatClientBuilder.build().mutate())
				.build();

		this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
				.documentRetriever(documentRetriever)
				.queryTransformers(queryTransformer)
				.build();
	}

	@PostMapping("/rag/compression/{chatId}")
	public String rag(
			@RequestBody String prompt,
			@PathVariable("chatId") String conversationId
	) {

		return chatClient.prompt()
				.advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
				.advisors(advisors -> advisors.param(
						AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
				.user(prompt)
				.call()
				.content();
	}

}

ModuleRAGMemoryController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGMemoryController {

	private final ChatClient chatClient;

	private final MessageChatMemoryAdvisor chatMemoryAdvisor;

	private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

	public ModuleRAGMemoryController(
			ChatClient.Builder chatClientBuilder,
			ChatMemory chatMemory,
			VectorStore vectorStore
	) {

		this.chatClient = chatClientBuilder.build();
		this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
				.build();

		this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
				.documentRetriever(VectorStoreDocumentRetriever.builder()
						.similarityThreshold(0.50)
						.vectorStore(vectorStore)
						.build())
				.build();
	}

	@PostMapping("/rag/memory/{chatId}")
	public String chatWithDocument(
			@RequestBody String prompt,
			@PathVariable("chatId") String conversationId
	) {

		return chatClient.prompt()
				.advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
				.advisors(advisors -> advisors.param(
						AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
				.user(prompt)
				.call()
				.content();
	}

}

ModuleRAGRewriteController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGRewriteController {

	private final ChatClient chatClient;

	private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

	public ModuleRAGRewriteController(
			ChatClient.Builder chatClientBuilder,
			VectorStore vectorStore
	) {

		this.chatClient = chatClientBuilder.build();

		var documentRetriever = VectorStoreDocumentRetriever.builder()
				.vectorStore(vectorStore)
				.similarityThreshold(0.50)
				.build();

		var queryTransformer = RewriteQueryTransformer.builder()
				.chatClientBuilder(chatClientBuilder.build().mutate())
				.targetSearchSystem("vector store")
				.build();

		this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
				.documentRetriever(documentRetriever)
				.queryTransformers(queryTransformer)
				.build();
	}

	@PostMapping("/rag/rewrite")
	public String rag(@RequestBody String prompt) {

		return chatClient.prompt()
				.advisors(retrievalAugmentationAdvisor)
				.user(prompt)
				.call()
				.content();
	}
}

ModuleRAGTranslationController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGTranslationController {

	private final ChatClient chatClient;

	private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

	public ModuleRAGTranslationController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
		this.chatClient = chatClientBuilder.build();

		var documentRetriever = VectorStoreDocumentRetriever.builder()
				.vectorStore(vectorStore)
				.similarityThreshold(0.50)
				.build();

		var queryTransformer = TranslationQueryTransformer.builder()
				.chatClientBuilder(chatClientBuilder.build().mutate())
				.targetLanguage("english")
				.build();

		this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
				.documentRetriever(documentRetriever)
				.queryTransformers(queryTransformer)
				.build();
	}

	@PostMapping("/rag/translation")
	public String rag(@RequestBody String prompt) {

		return chatClient.prompt()
				.advisors(retrievalAugmentationAdvisor)
				.user(prompt)
				.call()
				.content();
	}

}

Summary

Spring AI provides out-of-the-box RAG flows through RetrievalAugmentationAdvisor. There are two main categories. The first is Sequential RAG Flows (Naive RAG and Advanced RAG). The second is Modular RAG, inspired by the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks", which is divided into the Pre-Retrieval, Retrieval, Post-Retrieval, and Generation stages. The Builder of RetrievalAugmentationAdvisor exposes configuration for the Pre-Retrieval components (queryAugmenter, queryTransformers, queryExpander) and the Retrieval components (documentRetriever, documentJoiner).
