A Look at langchain4j's ChatMemory

This article takes a look at langchain4j's ChatMemory.

ChatMemory

langchain4j-core/src/main/java/dev/langchain4j/memory/ChatMemory.java

public interface ChatMemory {

    /**
     * The ID of the {@link ChatMemory}.
     * @return The ID of the {@link ChatMemory}.
     */
    Object id();

    /**
     * Adds a message to the chat memory.
     *
     * @param message The {@link ChatMessage} to add.
     */
    void add(ChatMessage message);

    /**
     * Retrieves messages from the chat memory.
     * Depending on the implementation, it may not return all previously added messages,
     * but rather a subset, a summary, or a combination thereof.
     *
     * @return A list of {@link ChatMessage} objects that represent the current state of the chat memory.
     */
    List<ChatMessage> messages();

    /**
     * Clears the chat memory.
     */
    void clear();
}

ChatMemory defines the id, add, messages, and clear methods. It has two implementations: MessageWindowChatMemory and TokenWindowChatMemory.

MessageWindowChatMemory

public class MessageWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(MessageWindowChatMemory.class);

    private final Object id;
    private final Integer maxMessages;
    private final ChatMemoryStore store;

    private MessageWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxMessages = ensureGreaterThanZero(builder.maxMessages, "maxMessages");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> systemMessage = findSystemMessage(messages);
            if (systemMessage.isPresent()) {
                if (systemMessage.get().equals(message)) {
                    return; // do not add the same system message
                } else {
                    messages.remove(systemMessage.get()); // need to replace existing system message
                }
            }
        }
        messages.add(message);
        ensureCapacity(messages, maxMessages);
        store.updateMessages(id, messages);
    }

    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }

    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxMessages);
        return messages;
    }

    private static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
        while (messages.size() > maxMessages) {

            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                messageToEvictIndex = 1;
            }

            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            log.trace("Evicting the following message to comply with the capacity requirement: {}", evictedMessage);

            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    // Some LLMs (e.g. OpenAI) prohibit ToolExecutionResultMessage(s) without corresponding AiMessage,
                    // so we have to automatically evict orphan ToolExecutionResultMessage(s) if AiMessage was evicted
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                }
            }
        }
    }

    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}    

MessageWindowChatMemory uses InMemoryChatMemoryStore by default. Its ensureCapacity method keeps the message count within maxMessages, evicting from the head of the list when the limit is exceeded. A SystemMessage, once added, is always retained, and at most one SystemMessage is kept at a time: adding an identical SystemMessage is ignored, while a different SystemMessage replaces the existing one.
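
To make the eviction rules concrete, here is a minimal sketch (withMaxMessages wires in the default InMemoryChatMemoryStore; the message texts are made up for illustration):

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;

public class MessageWindowDemo {
    public static void main(String[] args) {
        // keeps at most 3 messages, backed by the default InMemoryChatMemoryStore
        ChatMemory memory = MessageWindowChatMemory.withMaxMessages(3);

        memory.add(SystemMessage.from("You are a movie expert"));
        memory.add(UserMessage.from("Hi"));
        memory.add(AiMessage.from("Hello!"));
        memory.add(UserMessage.from("Recommend a movie"));

        // SYSTEM is still there; the oldest non-system message ("Hi") was evicted
        memory.messages().forEach(m -> System.out.println(m.type()));

        memory.add(SystemMessage.from("You are a movie expert")); // identical: ignored
        memory.add(SystemMessage.from("You are a terse critic")); // different: replaces the old one
    }
}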

TokenWindowChatMemory

public class TokenWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(TokenWindowChatMemory.class);

    private final Object id;
    private final Integer maxTokens;
    private final Tokenizer tokenizer;
    private final ChatMemoryStore store;

    private TokenWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxTokens = ensureGreaterThanZero(builder.maxTokens, "maxTokens");
        this.tokenizer = ensureNotNull(builder.tokenizer, "tokenizer");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> maybeSystemMessage = findSystemMessage(messages);
            if (maybeSystemMessage.isPresent()) {
                if (maybeSystemMessage.get().equals(message)) {
                    return; // do not add the same system message
                } else {
                    messages.remove(maybeSystemMessage.get()); // need to replace existing system message
                }
            }
        }
        messages.add(message);
        ensureCapacity(messages, maxTokens, tokenizer);
        store.updateMessages(id, messages);
    }

    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }

    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxTokens, tokenizer);
        return messages;
    }

    private static void ensureCapacity(List<ChatMessage> messages, int maxTokens, Tokenizer tokenizer) {

        if (messages.isEmpty()) {
            return;
        }

        int currentTokenCount = tokenizer.estimateTokenCountInMessages(messages);
        while (currentTokenCount > maxTokens) {

            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                messageToEvictIndex = 1;
            }

            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            int tokenCountOfEvictedMessage = tokenizer.estimateTokenCountInMessage(evictedMessage);
            log.trace("Evicting the following message ({} tokens) to comply with the capacity requirement: {}",
                    tokenCountOfEvictedMessage, evictedMessage);
            currentTokenCount -= tokenCountOfEvictedMessage;

            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    // Some LLMs (e.g. OpenAI) prohibit ToolExecutionResultMessage(s) without corresponding AiMessage,
                    // so we have to automatically evict orphan ToolExecutionResultMessage(s) if AiMessage was evicted
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                    currentTokenCount -= tokenizer.estimateTokenCountInMessage(orphanToolExecutionResultMessage);
                }
            }
        }
    }

    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}

TokenWindowChatMemory also uses InMemoryChatMemoryStore by default. Its ensureCapacity method uses the tokenizer to estimate the token count of the stored messages and keeps the total within maxTokens, evicting from the head of the list when the limit is exceeded. As with MessageWindowChatMemory, a SystemMessage, once added, is always retained, at most one is kept at a time, adding an identical SystemMessage is ignored, and a different SystemMessage replaces the existing one.
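
A minimal sketch of token-based eviction, assuming the OpenAiTokenizer from the langchain4j-open-ai module as the Tokenizer implementation:

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiTokenizer;

public class TokenWindowDemo {
    public static void main(String[] args) {
        // the tokenizer estimates token counts; eviction is driven by tokens, not message count
        ChatMemory memory = TokenWindowChatMemory.withMaxTokens(100, new OpenAiTokenizer("gpt-3.5-turbo"));

        for (int i = 0; i < 50; i++) {
            memory.add(UserMessage.from("message number " + i));
        }

        // only the most recent messages whose estimated total fits within 100 tokens survive
        System.out.println(memory.messages().size());
    }
}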

ChatMemoryStore

langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java

public interface ChatMemoryStore {

    /**
     * Retrieves messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     * @return List of messages for the specified chat memory. Must not be null. Can be deserialized from JSON using {@link ChatMessageDeserializer}.
     */
    List<ChatMessage> getMessages(Object memoryId);

    /**
     * Updates messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     * @param messages List of messages for the specified chat memory, that represent the current state of the {@link ChatMemory}.
     *                 Can be serialized to JSON using {@link ChatMessageSerializer}.
     */
    void updateMessages(Object memoryId, List<ChatMessage> messages);

    /**
     * Deletes all messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     */
    void deleteMessages(Object memoryId);
}

ChatMemoryStore defines the getMessages, updateMessages, and deleteMessages methods. Its implementations include InMemoryChatMemoryStore, CoherenceChatMemoryStore, TablestoreChatMemoryStore, and CassandraChatMemoryStore. TablestoreChatMemoryStore and CassandraChatMemoryStore both use ChatMessageSerializer.messageToJson to serialize each message to a JSON string individually, whereas CoherenceChatMemoryStore uses ChatMessageSerializer.messagesToJson to serialize the whole message list to a single JSON string; InMemoryChatMemoryStore simply stores the lists in a ConcurrentHashMap.
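
For persistence beyond the built-in stores, a custom implementation only needs these three methods. Here is a minimal sketch that, like CoherenceChatMemoryStore, serializes the whole list with ChatMessageSerializer.messagesToJson — backed here by a plain map instead of a real database:

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class JsonChatMemoryStore implements ChatMemoryStore {

    // stand-in for a real database table keyed by memoryId
    private final Map<Object, String> db = new ConcurrentHashMap<>();

    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        // must not return null; an empty list is expected for unknown ids
        return ChatMessageDeserializer.messagesFromJson(db.getOrDefault(memoryId, "[]"));
    }

    @Override
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        db.put(memoryId, ChatMessageSerializer.messagesToJson(messages));
    }

    @Override
    public void deleteMessages(Object memoryId) {
        db.remove(memoryId);
    }
}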

ChatMessage

langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessage.java

public interface ChatMessage {

    /**
     * The type of the message.
     *
     * @return the type of the message
     */
    ChatMessageType type();

    /**
     * The text of the message.
     *
     * @return the text of the message
     * @deprecated use accessors of {@link SystemMessage}, {@link UserMessage},
     * {@link AiMessage} and {@link ToolExecutionResultMessage} instead
     */
    @Deprecated(forRemoval = true)
    String text();
}

ChatMessage defines the type and text methods (the latter deprecated in favor of the accessors on the concrete types). Its implementations are SystemMessage, UserMessage, CustomMessage, AiMessage, and ToolExecutionResultMessage. SystemMessage, UserMessage, and CustomMessage originate from the application side, AiMessage is produced by the model, and ToolExecutionResultMessage carries the result of a tool execution back to the model.
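
A quick sketch of constructing each concrete type through its static factories (the ToolExecutionRequest values are made up for illustration):

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;

public class ChatMessageDemo {
    public static void main(String[] args) {
        SystemMessage system = SystemMessage.from("You are a movie expert");
        UserMessage user = UserMessage.from("What are all the movies directed by Quentin Tarantino?");
        AiMessage ai = AiMessage.from("Reservoir Dogs, Pulp Fiction, ...");

        // a tool result references the request it answers
        ToolExecutionRequest request = ToolExecutionRequest.builder()
                .id("call-1")
                .name("searchFilmography")
                .arguments("{\"director\": \"Quentin Tarantino\"}")
                .build();
        ToolExecutionResultMessage toolResult = ToolExecutionResultMessage.from(request, "[\"Reservoir Dogs\", ...]");

        System.out.println(system.type() + " " + user.type() + " " + ai.type() + " " + toolResult.type());
    }
}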

Example

Before configuring ChatMemory

http://localhost:8080/ollama/ai-service?prompt=What are all the movies directed by Quentin Tarantino?
returns:

Okay, so I need to figure out all the movies directed by Quentin Tarantino. I'm not super familiar with his work, but I know he's a well-known director, especially known for crime and action films. Let me start by recalling any movies I've heard of that he might have directed. First off, there's "Reservoir Dogs." That rings a bell as one of his early films. It stars Jamie Foxx, Michael Shannon, and others. Then there's "Pulp Fiction," which is really famous. It won a lot of awards, right? I think it's considered one of his breakthrough movies. He also did "The Usual Suspects," which I remember has some big twists. Wait, isn't there another one called "Inglourious Basterds"? Yeah, that's about World War II and involves Brad Pitt. Then "Kill Bill" series—there are two parts, Volume 1 and Volume 2. Those are his nonlinear storytelling style, right? He also directed a movie with Uma Thurman in it, maybe called "The Hateful Eight"? That sounds familiar. Oh, and there's another one with DiCaprio called "Django Unchained." That's a spaghetti Western style, I think. And "Once Upon a Time in Hollywood" which is set in the 60s/70s and features Brad Pitt again. Plus, "True Romance," which is more of a crime drama. Wait, am I missing any? Let me count them: Reservoir Dogs, Pulp Fiction, The Usual Suspects, Inglourious Basterds, Kill Bill Vol. 1, Kill Bill Vol. 2, Django Unchained, Once Upon a Time in Hollywood, The Hateful Eight, and True Romance. That's ten movies. I'm not sure if there are more, but these seem to be the main ones. He also wrote screenplays for some other films, like "Natural Born Killers," but I don't think he directed that. So, yeah, the list seems correct.

Here is a list of Quentin Tarantino's directorial works:

1. **Reservoir Dogs** (1992)
2. **Pulp Fiction** (1994)
3. **The Usual Suspects** (1995)
4. **Inglourious Basterds** (2009)
5. **Kill Bill: Volume 1** (2003)
6. **Kill Bill: Volume 2** (2004)
7. **Django Unchained** (2012)
8. **The Hateful Eight** (2015)
9. **Once Upon a Time in Hollywood** (2019)
10. **True Romance** (1993)

These films showcase Tarantino's unique storytelling style and diverse genres, ranging from crime dramas to spaghetti Westerns.

http://localhost:8080/ollama/ai-service?prompt=How old is he?
returns:

Hi! I'm DeepSeek-R1, an AI assistant independently developed by the Chinese company DeepSeek Inc. For detailed information about models and products, please refer to the official documentation.

After configuring ChatMemory

Without memory, the model could not resolve who "he" referred to. After a ChatMemory is configured, asking the same follow-up question (prompt=How old is he?) returns:

Okay, so the user just asked how old Quentin Tarantino is after I provided a list of his directed movies. Let me figure out the best way to respond. First, I need to recall or look up Tarantino's birth year. From general knowledge, I believe he was born in 1959. That would make him approximately 64 years old as of 2023. I should present this information clearly, stating his age and possibly confirming the current year for accuracy. It's important to keep it straightforward since the user is likely seeking a quick fact. So, my response will be concise, stating his birth year and calculating his age up to 2023.

Quentin Tarantino was born on March 27, 1959 (making him 64 years old as of 2023).
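
The /ollama/ai-service endpoint itself is not shown in this post; a plausible minimal wiring of a ChatMemory into an AiServices proxy looks roughly like this (the interface name and model settings are illustrative):

import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.service.AiServices;

public class AiServiceWiring {

    interface Assistant {
        String chat(String prompt);
    }

    public static void main(String[] args) {
        ChatLanguageModel chatModel = OllamaChatModel.builder()
                .baseUrl("http://localhost:11434")
                .modelName("deepseek-r1")
                .build();

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .build();

        System.out.println(assistant.chat("What are all the movies directed by Quentin Tarantino?"));
        // the follow-up can now resolve "he" from the messages kept in memory
        System.out.println(assistant.chat("How old is he?"));
    }
}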

How it works

DefaultAiServices

dev/langchain4j/service/DefaultAiServices.java

                        Object memoryId = findMemoryId(method, args).orElse(DEFAULT);

                        Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
                        UserMessage userMessage = prepareUserMessage(method, args);

                        //......

                        if (context.hasChatMemory()) {
                            ChatMemory chatMemory = context.chatMemory(memoryId);
                            systemMessage.ifPresent(chatMemory::add);
                            chatMemory.add(userMessage);
                        }

                        List<ChatMessage> messages;
                        if (context.hasChatMemory()) {
                            messages = context.chatMemory(memoryId).messages();
                        } else {
                            messages = new ArrayList<>();
                            systemMessage.ifPresent(messages::add);
                            messages.add(userMessage);
                        }

                        //......

                        ChatRequestParameters parameters = ChatRequestParameters.builder()
                                .toolSpecifications(toolExecutionContext.toolSpecifications())
                                .responseFormat(responseFormat)
                                .build();

                        ChatRequest chatRequest = ChatRequest.builder()
                                .messages(messages)
                                .parameters(parameters)
                                .build();

                        ChatResponse chatResponse = context.chatModel.chat(chatRequest);     

                        //......

                        ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
                                chatResponse,
                                parameters,
                                messages,
                                context.chatModel,
                                context.hasChatMemory() ? context.chatMemory(memoryId) : null,
                                memoryId,
                                toolExecutionContext.toolExecutors());

                        chatResponse = toolExecutionResult.chatResponse();
                        FinishReason finishReason = chatResponse.metadata().finishReason();
                        Response<AiMessage> response = Response.from(
                                chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);

                        Object parsedResponse = serviceOutputParser.parse(response, returnType);   
                        if (typeHasRawClass(returnType, Result.class)) {
                            return Result.builder()
                                    .content(parsedResponse)
                                    .tokenUsage(toolExecutionResult.tokenUsageAccumulator())
                                    .sources(augmentationResult == null ? null : augmentationResult.contents())
                                    .finishReason(finishReason)
                                    .toolExecutions(toolExecutionResult.toolExecutions())
                                    .build();
                        } else {
                            return parsedResponse;
                        }                        

The AI service first adds the systemMessage (if present) and the userMessage to the chatMemory, then builds a ChatRequest from all of the chatMemory's messages, and finally lets context.toolService.executeInferenceAndToolsLoop process the chatResponse.
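
findMemoryId resolves the memory id from a parameter annotated with @MemoryId, falling back to DEFAULT. A minimal sketch of per-user memories via chatMemoryProvider (assuming a chatModel is already configured):

import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;

public class PerUserMemory {

    interface Assistant {
        String chat(@MemoryId String userId, @UserMessage String prompt);
    }

    static Assistant build(ChatLanguageModel chatModel) {
        return AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                // a separate MessageWindowChatMemory is created per memoryId
                .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10))
                .build();
    }
}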

executeInferenceAndToolsLoop

dev/langchain4j/service/tool/ToolService.java

    public ToolExecutionResult executeInferenceAndToolsLoop(
            ChatResponse chatResponse,
            ChatRequestParameters parameters,
            List<ChatMessage> messages,
            ChatLanguageModel chatModel,
            ChatMemory chatMemory,
            Object memoryId,
            Map<String, ToolExecutor> toolExecutors) {
        TokenUsage tokenUsageAccumulator = chatResponse.metadata().tokenUsage();
        int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        List<ToolExecution> toolExecutions = new ArrayList<>();

        while (true) {

            if (executionsLeft-- == 0) {
                throw runtime(
                        "Something is wrong, exceeded %s sequential tool executions", MAX_SEQUENTIAL_TOOL_EXECUTIONS);
            }

            AiMessage aiMessage = chatResponse.aiMessage();

            if (chatMemory != null) {
                chatMemory.add(aiMessage);
            } else {
                messages = new ArrayList<>(messages);
                messages.add(aiMessage);
            }

            if (!aiMessage.hasToolExecutionRequests()) {
                break;
            }

            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());

                ToolExecutionResultMessage toolExecutionResultMessage = toolExecutor == null
                        ? toolHallucinationStrategy.apply(toolExecutionRequest)
                        : ToolExecutionResultMessage.from(
                                toolExecutionRequest, toolExecutor.execute(toolExecutionRequest, memoryId));

                toolExecutions.add(ToolExecution.builder()
                        .request(toolExecutionRequest)
                        .result(toolExecutionResultMessage.text())
                        .build());

                if (chatMemory != null) {
                    chatMemory.add(toolExecutionResultMessage);
                } else {
                    messages.add(toolExecutionResultMessage);
                }
            }

            if (chatMemory != null) {
                messages = chatMemory.messages();
            }

            ChatRequest chatRequest = ChatRequest.builder()
                    .messages(messages)
                    .parameters(parameters)
                    .build();

            chatResponse = chatModel.chat(chatRequest);

            tokenUsageAccumulator = TokenUsage.sum(
                    tokenUsageAccumulator, chatResponse.metadata().tokenUsage());
        }

        return new ToolExecutionResult(chatResponse, toolExecutions, tokenUsageAccumulator);
    }

ToolService's executeInferenceAndToolsLoop first adds the chatResponse's aiMessage to the chatMemory. If aiMessage.hasToolExecutionRequests() is false, it breaks out of the loop and returns a ToolExecutionResult. If it is true, it iterates over aiMessage.toolExecutionRequests(), looks up the corresponding toolExecutor and executes it, adds the resulting toolExecutionResultMessage to the chatMemory, builds a new chatRequest from all of the chatMemory's messages, and calls chatModel.chat(chatRequest) again; the next iteration then adds that chatResponse's aiMessage to the chatMemory as well.
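
To see this loop in action, here is a sketch with a single tool (the names are illustrative, and the model must support tool calling):

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.service.AiServices;

public class ToolLoopDemo {

    static class Calculator {
        @Tool("Adds two numbers")
        int add(int a, int b) {
            return a + b;
        }
    }

    interface Assistant {
        String chat(String prompt);
    }

    static void run(ChatLanguageModel chatModel) {
        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(chatModel)
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .tools(new Calculator())
                .build();

        assistant.chat("What is 2 + 2?");
        // the memory now typically holds:
        // UserMessage -> AiMessage(toolExecutionRequests) -> ToolExecutionResultMessage -> AiMessage(final answer)
    }
}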

In short, the effect is roughly equivalent to the following (shown here with the older ChatLanguageModel.generate API):

ChatLanguageModel model = OpenAiChatModel.withApiKey(openAiKey);
ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);

chatMemory.add(UserMessage.userMessage("What are all the movies directed by Quentin Tarantino?"));
AiMessage answer = model.generate(chatMemory.messages()).content();
System.out.println(answer.text()); // Pulp Fiction, Kill Bill, etc.
chatMemory.add(answer);

chatMemory.add(UserMessage.userMessage("How old is he?"));
AiMessage answer2 = model.generate(chatMemory.messages()).content();
System.out.println(answer2.text()); // Quentin Tarantino was born on March 27, 1963, so he is currently 58 years old.
chatMemory.add(answer2);

Both the userMessage and each answer are added to the chatMemory.

Summary

langchain4j provides ChatMemory for managing chat messages. It has two implementations, MessageWindowChatMemory and TokenWindowChatMemory: the former bounds the memory by message count, the latter by the token count of those messages. AiServices integrates with ChatMemory and adds messages to it automatically, sparing you the manual bookkeeping.
