SpringAI--基于MySQL的持久化对话记忆实现

SpringAI–基于MySQL的持久化对话记忆实现

项目源码

对话记忆官方介绍

SpringAI目前提供了一些将对话保存到不同数据源中的实现,比如:

  • InMemoryChatMemory 基于内存存储
  • CassandraChatMemory 在Cassandra中带有过期时间的持久化存储。
  • Neo4jChatMemory 在Neo4j中没有过期时间限制的持久化存储。
  • JdbcChatMemory 在JDBC中没有过期时间限制的持久化存储。

如果要将对话持久化到数据库中,就可以使用JdbcChatMemory。但是spring-ai-starter-model-jdbc依赖模板版本很少,而且缺乏相关介绍,Maven官方仓库还搜不到依赖,所以不推荐使用。在Spring仓库能搜到,但是用的人太少了。

SpringAI源码中只有InMemoryChatMemory实现了ChatMemory。

所以可以自己自定义一个数据库持久化对话记忆。

自定义实现

Spring AI的对话记忆实现非常巧妙,解耦了“存储”和“记忆算法”。

  • 存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。
  • 记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作
    • MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
    • PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
    • VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。

ChatMemory接口的方法并不多,需要实现对话消息的增、删、查就可以了。

源码中的conversationId就相当于会话id,每个用户可以有自己的会话id,这个值可以自己来生成,在调用的时候传过去就可以了,就是根据这个值实现了多轮对话(多轮对话的本质实际上就是把历史消息拼接上新的消息再一起发送给大模型)。

自定义持久化ChatMemory

版本

  • JDK21
  • Springboot 3.4.5
  • Spring AI Alibaba 1.0.0-M6.1
  • mysql驱动 8.0.32
  • mybatis plus 3.5.12
依赖
<!--Spring AI Alibaba-->
<!--Spring AI 还不支持国产大模型,所以使用Alibaba-->
<dependency>
    <groupId>com.alibaba.cloud.ai</groupId>
    <artifactId>spring-ai-alibaba-starter</artifactId>
    <version>1.0.0-M6.1</version>
</dependency>

<!-- MySQL 驱动 -->
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>8.0.32</version>
</dependency>

<!-- https://mvnrepository.com/artifact/com.baomidou/mybatis-plus-boot-starter -->
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
    <version>3.5.12</version>
</dependency>

<!-- 3.5.9及以上版本想使用mybatis plus分页配置需要单独引入-->
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus-jsqlparser</artifactId>
    <version>3.5.12</version> <!-- 确保版本和 MyBatis Plus 主包一致 -->
</dependency>

SQL

CREATE TABLE ai_chat_memory (
    id              BIGINT AUTO_INCREMENT PRIMARY KEY,
    conversation_id VARCHAR(255) NOT NULL comment '会话id',
    type            VARCHAR(20)  NOT NULL comment '消息类型',
    content         TEXT         NOT NULL comment '消息内容',
    create_time      TIMESTAMP    NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',
    update_time      TIMESTAMP default CURRENT_TIMESTAMP not null on update CURRENT_TIMESTAMP comment '更新时间',
    is_delete        tinyint  default 0                 not null comment '是否删除',
    INDEX idx_conv (conversation_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

配置

大模型使用的是阿里的百炼大模型

spring:
  application:
    name: yy-ai-agent
  profiles:
    active: local
  ai:
    dashscope:
      api-key: ${DASH_SCOPE_API_KEY}
      chat:
        options:
          model: qwen-max
  datasource:
    url: jdbc:mysql://localhost:3306/your_database?useUnicode=true&characterEncoding=UTF-8&connectionCollation=utf8mb4_unicode_ci&serverTimezone=Asia/Shanghai
    username: your_username
    password: your_password
    driver-class-name: com.mysql.cj.jdbc.Driver


mybatis-plus:
  configuration:
    map-underscore-to-camel-case: false
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
  global-config:
    db-config:
      logic-delete-field: isDelete # 全局逻辑删除的实体字段名
      logic-delete-value: 1 # 逻辑已删除值(默认为 1)
      logic-not-delete-value: 0 # 逻辑未删除值(默认为 0)

model

import com.baomidou.mybatisplus.annotation.*;

import java.io.Serializable;
import java.util.Date;
import lombok.Data;

/**
 * 
 * @TableName ai_chat_memory
 */
@TableName(value ="ai_chat_memory")
@Data
public class AiChatMemory implements Serializable {
    /**
     * 
     */
    @TableId(type = IdType.AUTO)
    private Long id;

    /**
     * 会话id
     */
    @TableField("conversation_id")
    private String conversationId;

    /**
     * 消息类型
     */
    @TableField("type")
    private String type;

    /**
     * 消息内容
     */
    @TableField("content")
    private String content;

    /**
     * 创建时间
     */
    @TableField("create_time")
    private Date createTime;

    /**
     * 更新时间
     */
    @TableField("update_time")
    private Date updateTime;

    /**
     * 是否删除
     */
    @TableLogic
    @TableField("is_delete")
    private Integer isDelete;

}

mapper

注意在项目启动类上加上@MapperScan("自己mapper所在报名")

@Mapper
public interface AiChatMemoryMapper extends BaseMapper<AiChatMemory> {

}

mybatis plus分页配置

这块有个坑,mybatis plus 3.5.9及以上版本想使用mybatis plus分页配置需要再引入一个mybatis-plus-jsqlparser的包,单纯只引入mybatis-plus-spring-boot3-starter这个依赖会找不到PaginationInnerInterceptor这个类。

import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;

@Configuration
public class MyBatisPlusConfig {

    /**
     * 注册 MyBatis-Plus 拦截器并添加分页插件
     */
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // 指定数据库类型为 MySQL,构造分页内置拦截器
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return interceptor;
    }

}

ChatMemory实现

import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.core.aiagent.mapper.AiChatMemoryMapper;
import com.core.aiagent.model.AiChatMemory;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

@Component
public class MyBatisPlusChatMemory implements ChatMemory {

    @Autowired
    private AiChatMemoryMapper mapper;

    @Override
    public void add(String conversationId, Message message) {
        AiChatMemory aiChatMemory = new AiChatMemory();
        aiChatMemory.setConversationId(conversationId);
        aiChatMemory.setType(message.getMessageType().getValue());
        aiChatMemory.setContent(message.getText());

        mapper.insert(aiChatMemory);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        List<AiChatMemory> aiChatMemories = new ArrayList<>();
        for (Message message : messages) {
            AiChatMemory aiChatMemory = new AiChatMemory();
            aiChatMemory.setConversationId(conversationId);
            aiChatMemory.setType(message.getMessageType().getValue());
            aiChatMemory.setContent(message.getText());
            aiChatMemories.add(aiChatMemory);
        }

        mapper.insert(aiChatMemories);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        // 分页查询最近N条记录
        Page<AiChatMemory> page = new Page<>(1, lastN);
        QueryWrapper<AiChatMemory> wrapper = new QueryWrapper<>();
        wrapper.eq("conversation_id", conversationId)
                .orderByDesc("create_time");

        List<AiChatMemory> aiChatMemories = mapper.selectList(wrapper);
        // 反转列表,使得最新的消息在最后
        Collections.reverse(aiChatMemories);

        // 转换为Message对象
        List<Message> messages = new ArrayList<>();
        for (AiChatMemory aiChatMemory : aiChatMemories) {
            String type = aiChatMemory.getType();
            switch (type) {
                case "user" -> messages.add(new UserMessage(aiChatMemory.getContent()));
                case "assistant" -> messages.add(new AssistantMessage(aiChatMemory.getContent()));
                case "system" -> messages.add(new SystemMessage(aiChatMemory.getContent()));
                default -> throw new IllegalArgumentException("Unknown message type: " + type);
            }
        }
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        // 删除指定会话的所有消息
        QueryWrapper<AiChatMemory> wrapper = new QueryWrapper<>();
        wrapper.eq("conversation_id", conversationId);
        mapper.delete(wrapper);
    }
}

使用自定义持久化的ChatMemory

import com.core.aiagent.advisor.MyLoggerAdvisor;
import com.core.aiagent.chatmemory.MyBatisPlusChatMemory;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;

@Component
@Slf4j
public class LoveApp {

    private ChatClient chatClient;

    // mysql对话记忆
    @Autowired
    private MyBatisPlusChatMemory chatMemory;

    private static final String SYSTEM_PROMPT = "自己随便写点什么";

    public LoveApp(ChatModel dashScopeChatModel) {

        // 对话记忆,创建一个内存对话记忆
        //ChatMemory chatMemory = new InMemoryChatMemory();

        this.chatClient = ChatClient.builder(dashScopeChatModel)
                .defaultSystem(SYSTEM_PROMPT)
                // 指定默认advisor(类似拦截器),MessageChatMemoryAdvisor实现对话记忆功能,chatMemory是用来保存对话的
                // .defaultAdvisors(...):注册「要用记忆」的能力。
                .defaultAdvisors(
                        //new MessageChatMemoryAdvisor(chatMemory),
                        new MyLoggerAdvisor()
                )
                .build();
    }

    /**
     *
     * @param message 用户消息
     * @param chatId 对话记忆的id
     * @return ai回复
     */
    public String doChat(String message, String chatId) {

        ChatResponse chatResponse = chatClient.prompt()
                .user(message)
                // 指定对话记忆的id和对话记忆的长度(10条)
                // .advisors(spec->...):告诉该能力「这次用哪个会话」「取多少历史」
                .advisors(spec -> spec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId)
                        .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 5))
                // 自定义mysql对话记忆
                .advisors(new MessageChatMemoryAdvisor(chatMemory))
                .call()
                .chatResponse();

        String text = chatResponse.getResult().getOutput().getText();
        //log.info("用户消息: {}, 返回消息: {}", message, text);

        return text;
    }
}

单元测试

先执行testChat(),再执行testChatMemory(),会发现在执行testChatMemory()时,大模型返回的是testChat()方法中的执行的信息。

@Test
    void testChat() {
        String chatId = UUID.randomUUID().toString();
        System.out.println("chatId: " + chatId);
        // 第一轮对话
        String message = "我是一个程序员,我叫xx";
        String answer = loveApp.doChat(message, chatId);
        Assertions.assertNotNull(answer);

        // 第二轮对话
        message = "我的另一半是yy,我想让她更爱我";
        answer = loveApp.doChat(message, chatId);
        Assertions.assertNotNull(answer);

        // 第三轮对话
        message = "我的另一半是谁来着,我刚刚提到过";
        answer = loveApp.doChat(message, chatId);
        Assertions.assertNotNull(answer);
    }

@Test
    void testChatMemory() {
        // 取出数据库中的conversation_id
        String chatId = "804e52bf-aa75-4a07-bb2e-ec93f47f4e1e";
        System.out.println("chatId: " + chatId);
        // 第一轮对话
        String message = "我是谁,我的另一半叫什么";
        String answer = loveApp.doChat(message, chatId);
        Assertions.assertNotNull(answer);
    }
<think>我们正在讨论快速幂算法的记忆实现。快速幂算法(也称为幂的二进制分解法)是一种高效计算幂运算的方法,特别是对于大指数的情况。记忆化(Memoization)是一种优化技术,通过存储先前计算的结果来避免重复计算,从而减少计算时间。 快速幂算法的基本思想是利用指数的二进制表示和分治思想。例如,计算 $a^n$ 时,我们可以将 $n$ 表示为二进制,然后通过平方和乘法来组合结果。 记忆化快速幂的目标是避免在递归或迭代过程中重复计算相同的子问题。在递归实现快速幂时,我们可能会多次计算同一个指数幂(尤其是当指数有重叠部分时),因此使用记忆化可以显著提高效率。 根据引用[2]和引用[3],记忆化递归是一种将递归结果存储起来以便后续直接调用的技术。在快速幂的递归实现中,我们可以将已经计算过的指数对应的幂值存储在一个字典或数组中,当再次需要相同指数的幂时,直接返回存储的结果。 具体步骤: 1. 定义一个记忆化存储结构(如字典),用于存储已经计算过的幂的结果。 2. 在递归函数中,首先检查当前指数是否已经在记忆化存储中,如果是,则直接返回存储的结果。 3. 如果当前指数未计算,则按照快速幂的递归方法进行计算(将指数分为两半,递归计算),并将结果存储到记忆化结构中。 快速幂递归公式: - 如果指数n为0,返回1(因为 $a^0=1$)。 - 如果指数n为1,返回底数a。 - 如果n是偶数,则 $a^n = (a^{n/2})^2$。 - 如果n是奇数,则 $a^n = a \cdot (a^{(n-1)/2})^2$。 在递归计算时,我们首先检查记忆化存储中是否有 $a^{n/2}$ 或 $a^{(n-1)/2}$ 的值,如果有就直接使用,避免重复递归。 注意:由于指数可能非常大(如引用[4]中指数达到 $10^{18}$),递归深度可能会很大,但快速幂的递归深度是 $O(\log n)$,所以是可以接受的。但是,如果使用记忆化,我们可以避免重复计算,特别是当有多个相同指数需要计算时(如多次查询,引用[4]中的Q次查询),记忆化可以大大提高效率。 下面我们给出一个使用记忆化递归实现快速幂的Python代码示例: ```python def fast_power(a, n, memo=None): if memo is None: memo = {} # 检查是否已经计算过 if n in memo: return memo[n] # 递归基 if n == 0: return 1 if n == 1: return a # 递归计算 if n % 2 == 0: half = fast_power(a, n // 2, memo) result = half * half else: half = fast_power(a, (n - 1) // 2, memo) result = a * half * half # 存储结果 memo[n] = result return result ``` 然而,上述代码有一个问题:当指数n很大时,递归深度可能达到 $O(\log n)$,这在大多数情况下是可以接受的(例如,$n=10^{18}$ 时,递归深度大约为60)。但是,我们也可以使用迭代的快速幂,它更节省内存,而且不需要递归栈。 引用[4]中提到,当有多次查询(Q次)时,我们可以利用记忆化来避免重复计算相同的指数幂。但注意,每次查询的指数可能不同,所以记忆化存储可以跨查询使用。因此,我们可以将记忆化字典作为全局变量或类变量,这样在多次调用之间可以共享。 另外,引用[3]中提到,当输入是递增的,记忆化的效果会更好,因为之前计算的小指数幂会被后续的大指数幂利用。但即使不是递增,记忆化也能避免重复计算相同的指数幂。 迭代实现快速幂通常不需要记忆化,因为迭代过程本身没有重复计算。但是,如果我们需要多次计算不同指数的幂,并且这些指数有重复,那么我们可以用一个全局字典来存储所有计算过的幂,这样即使使用迭代方法,也可以利用记忆化避免重复计算。 迭代快速幂的记忆实现: 我们可以先写一个不带记忆化的迭代快速幂,然后在外层封装一个记忆化字典。每次计算幂时,先检查记忆化字典,如果已经计算过则直接返回,否则用迭代方法计算并存储。 迭代快速幂代码(无记忆化): ```python def iterative_fast_power(a, n): result = 1 base = a while n: if n & 1: result *= base base *= base n //= 2 # 或者 n >>= 1 return result ``` 加上记忆化(全局): ```python memo_dict = {} def iterative_fast_power_memo(a, n): # 如果指数n相同,但底数a不同,那么我们需要区分。所以记忆化的键应该是(a, n) key = (a, n) if key in memo_dict: return memo_dict[key] # 否则计算 result = 1 base = a # 注意:这里n在循环中会被修改,所以先保存原始指数 exp = n while exp: if exp & 1: result *= base base *= base exp //= 2 memo_dict[key] = result return result ``` 但是,上述全局记忆化在多次查询时,如果底数a和指数n的组合重复出现,就可以避免重复计算。但是,如果底数a是变化的,且指数n很大,那么记忆化字典可能会变得很大。因此,需要根据应用场景权衡。 另外,引用[1]和引用[3]提到矩阵快速幂,实际上快速幂算法可以推广到矩阵乘法,用于加速矩阵的幂运算(例如斐波那契数列的矩阵形式)。同样,矩阵快速幂也可以使用记忆化,但由于矩阵通常较大,存储矩阵幂可能占用较多内存,所以一般只用于指数相同的情况多次出现,且矩阵大小不是特别大的情况。 总结:快速幂的记忆实现主要适用于需要多次计算幂(特别是相同指数)的场景,通过存储已计算的结果来减少重复计算。递归实现中自然可以通过记忆化避免重复递归计算,而迭代实现则可以通过外部字典实现记忆化。 注意:在递归实现中,记忆化存储的是指数对应的结果,而迭代实现中,我们存储的是(a, n)对应的结果。由于底数a可能不同,所以用(a, n)作为键。 根据用户的问题,用户可能想了解如何实现快速幂的记忆化。我们提供了递归和迭代两种方式的记忆实现。 最后,我们生成相关问题: §§相关问题§§ 1. 快速幂算法的时间复杂度是多少? 2. 在什么情况下使用记忆化快速幂比普通快速幂更有优势? 3. 记忆化快速幂在矩阵幂运算中如何应用? 4. 如何处理大指数幂运算中的内存限制问题(例如记忆化字典过大)?</think>### 快速幂算法的记忆实现方法 快速幂算法(又称二进制取幂)通过指数分解降低幂运算的时间复杂度,从 $O(n)$ 优化至 $O(\log n)$。记忆化技术通过存储中间计算结果避免重复计算,进一步提升效率,尤其适合多次查询相同指数的场景(如引用[4]中 $Q$ 次查询)。以下是实现方法: --- #### **1. 递归实现 + 记忆化** **核心思想**: - 将 $a^n$ 分解为 $a^{n/2} \cdot a^{n/2}$(偶数)或 $a \cdot a^{(n-1)/2} \cdot a^{(n-1)/2}$(奇数) - 用字典存储已计算的 $a^k$ 结果,避免重复递归 **Python 实现**: ```python memo = {} # 全局记忆化字典 def fast_pow(a, n): if n == 0: return 1 if n == 1: return a if n in memo: return memo[n] # 直接返回已存储的结果 # 分解子问题 half = fast_pow(a, n // 2) result = half * half if n % 2 == 1: result *= a # 处理奇指数 memo[n] = result # 存储结果 return result ``` **优势**: - 时间复杂度 $O(\log n)$,空间复杂度 $O(\log n)$(递归栈 + 记忆化存储) - 对重复查询效率极高(如引用[4]的多次查询场景) --- #### **2. 迭代实现 + 记忆化** **核心思想**: - 通过二进制分解指数:$n = \sum b_i \cdot 2^i$ - 动态计算 $a^{2^i}$ 并组合结果 - 全局字典存储历史查询 **Python 实现**: ```python memo_global = {} # 全局记忆化字典 def iterative_fast_pow(a, n): if (a, n) in memo_global: # 检查是否已计算 return memo_global[(a, n)] res = 1 base = a while n: if n & 1: # 当前二进制位为1 res *= base base *= base # a^{2^i} -> a^{2^{i+1}} n //= 2 memo_global[(a, n)] = res # 存储结果 return res ``` **优势**: - 避免递归栈开销,空间复杂度 $O(1)$(不计记忆化存储) - 支持任意底数 $a$ 和指数 $n$(引用[4]中 $n \leq 10^{18}$) --- #### **3. 矩阵快速幂的记忆化(扩展)** 当 $a$ 为矩阵时(如斐波那契数列的矩阵形式),记忆化可加速矩阵幂运算: ```python matrix_memo = {} # 存储矩阵幂结果 def matrix_fast_pow(matrix, n): if n == 1: return matrix if n in matrix_memo: return matrix_memo[n] half = matrix_fast_pow(matrix, n // 2) result = multiply_matrix(half, half) # 矩阵乘法 if n % 2 == 1: result = multiply_matrix(result, matrix) matrix_memo[n] = result return result ``` **应用场景**: - 动态规划优化(如斐波那契数列 $O(\log n)$ 解法,引用[1][3]) - 大指数线性递推问题(如引用[4]中的 $N \leq 10^{18}$) --- ### **记忆化的关键优化点** 1. **键设计**: - 简单场景:用指数 $n$ 作为键($a$ 固定) - 通用场景:用元组 $(a, n)$ 作为键(引用[4]的多查询需求) 2. **内存管理**: - 限制字典大小(如 LRU 缓存),避免大 $n$ 导致内存溢出 - 分布式场景可用 Redis 等外部存储 3. **适用场景**: - 多次查询相同指数(如竞赛中的 $Q$ 次查询,引用[4]) - 指数递增时效果显著(引用[3]:"输入的 $n$ 递增则越算越快") > 通过记忆化,快速幂的多次查询时间复杂度可降至 $O(1)$(命中缓存)或 $O(\log n)$(未命中),空间换时间的策略在指数重复时收益显著[^2][^3]。 ---
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值