Spring AI 自定义Redis持久化ChatMemory(可直接复用)

若要实现自定义MySQL持久化ChatMemory,可以参考这篇文章
Spring AI 自定义数据库持久化的ChatMemory

这里我使用的是spring ai 1.0.0-M6.1 现在spring ai已经发布了正式版,其实现与现在有了一定的区别.

自定义实现ChatMemory

Spring AI的对话记忆实现非常巧妙,解耦了"储存"和"记忆算法",

  • 存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。

  • 记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作

    • MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
    • PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
    • VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。

我们可以单独修改ChatMemory储存来改变对话记忆的保存位置,而无需修改保存对话记忆的流程.

虽然官方文档没有给我们自定义ChatMemory实现的示例,但是我们可以直接去阅读默认实现类 InMemoryChatMemory 的源码

基于内存持久化的ChatMemory

其本质是实现了ChatMemory的增删查接口

ChatMemory

所以我们想实现自己的持久化,修改对应的储存实现就行了.

参考 InMemoryChatMemory 的源码,其实就是通过 ConcurrentHashMap 来维护对话信息,key 是对话 id(相当于房间号),value 是该对话 id 对应的消息列表。

自定义Redis持久化ChatMemory

由于List<Message>中Message是一个接口,虽然需要实现的接口不多,但是实现起来还是有一定复杂度的,一个最主要的问题是 消息和文本的转换。我们在保存消息时,要将消息从 Message 对象转为文件内的文本;读取消息时,要将文件内的文本转换为 Message 对象。也就是对象的序列化和反序列化。

我们本能地会想到通过 JSON 进行序列化,但实际操作中,我们发现这并不容易。原因是:

  1. 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
  2. 每种子类所拥有的字段都不一样,结构不统一
  3. 子类没有无参构造函数,而且没有实现 Serializable 序列化接口

在这里我们使用Kryo的序列化库来实现序列化

1)引入redis依赖

这里使用的是Spring 3.4.4 ,Java 21

        <!-- Redis -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

2)修改配置


# redis 配置
data:
  redis:
    port: 6379
    host: localhost
    database: 0

3)配置redis的bean注入

@Configuration
public class RedisTemplateConfig {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        template.setKeySerializer(RedisSerializer.string());
        return template;
    }

}

4)实现序列化

  1. 引入依赖
		<!-- 自定义持久化的序列化库-->
		<dependency>
    		<groupId>com.esotericsoftware</groupId>
    		<artifactId>kryo</artifactId>
    		<version>5.6.2</version>
		</dependency>
  1. 创建序列化实现工具类
@Component
public class MessageSerializer {

    // ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)
    private static final Kryo kryo = new Kryo();

    static {
        kryo.setRegistrationRequired(false);
        // 设置实例化策略(需确保兼容所有 Message 实现类)
        kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
    }

    /**
     * 使用 Kryo 将 Message 序列化为 Base64 字符串
     */
    public static String serialize(Message message) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
             Output output = new Output(baos)) {
            kryo.writeClassAndObject(output, message);  // ⚠️ 依赖动态注册和实例化策略
            output.flush();
            return Base64.getEncoder().encodeToString(baos.toByteArray());
        } catch (IOException e) {
            throw new RuntimeException("序列化失败", e);
        }
    }

    /**
     * 使用 Kryo 将 Base64 字符串反序列化为 Message 对象
     */
    public static Message deserialize(String base64) {
        try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));
             Input input = new Input(bais)) {
            return (Message) kryo.readClassAndObject(input);  // ⚠️ 依赖动态注册和实例化策略
        } catch (IOException e) {
            throw new RuntimeException("反序列化失败", e);
        }
    }
}

5)实现自定义ChatMemory


/**
 * 基于 Redis 的对话记忆存储实现
 */
@Service
@Slf4j
public class RedisChatMemory implements ChatMemory {
    // Redis 键前缀,避免键冲突
    private static final String KEY_PREFIX = "chat:memory:";
    private final RedisTemplate<String, Object> redisTemplate;
    
    @Resource
    private MySQLChatMemoryStore mySQLChatMemoryStore;

    public RedisChatMemory(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * 添加单条消息到对话历史
     *
     * @param conversationId 对话 ID
     * @param message        消息对象
     */
    @Override
    public void add(String conversationId, Message message) {
        add(conversationId, List.of(message));
    }

    /**
     * 添加多条消息到对话历史
     *
     * @param conversationId 对话 ID
     * @param messages       消息列表
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        // 获取现有消息列表
        List<Message> existingMessages = getFromRedis(conversationId);
        // 合并消息
        existingMessages.addAll(messages);
        // 保存更新后的消息列表
        setToRedis(conversationId, existingMessages);
        
        // 异步存储到MySQL数据库(若自己实现了可以使用)
//        mySQLChatMemoryStore.storeMessages(conversationId, messages);
        
        // 检查消息数量,如果超过20条则删除多余部分,只保留最新的20条
        if (existingMessages.size() > 20) {
            trimConversation(conversationId);
        }
        
        log.debug("已向对话 [{}] 添加 {} 条消息,当前总消息数: {}",
                conversationId, messages.size(), Math.min(existingMessages.size(), 20));
    }

    /**
     * 获取对话的最近 N 条消息
     *
     * @param conversationId 对话 ID
     * @param lastN          最近的消息数量
     * @return 消息列表
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<Message> allMessages = getFromRedis(conversationId);
        int skip = Math.max(0, allMessages.size() - lastN);
        return allMessages.stream()
                .skip(skip)
                .toList();
    }

    /**
     * 清空对话历史
     *
     * @param conversationId 对话 ID
     */
    @Override
    public void clear(String conversationId) {
        String key = getRedisKey(conversationId);
        redisTemplate.delete(key);
        log.debug("已清空对话 [{}] 的历史消息", conversationId);
    }
    
    /**
     * 清理对话历史,只保留最新的20条消息
     *
     * @param conversationId 对话 ID
     */
    public void trimConversation(String conversationId) {
        List<Message> allMessages = getFromRedis(conversationId);
        if (allMessages.size() > 20) {
            // 只保留最新的20条消息
            List<Message> recentMessages = allMessages.subList(allMessages.size() - 20, allMessages.size());
            setToRedis(conversationId, recentMessages);
            log.debug("已清理对话 [{}] 的历史消息,从 {} 条减少到 20 条", conversationId, allMessages.size());
        }
    }
    
    /**
     * 获取所有对话ID
     *
     * @return 对话ID集合
     */
    public List<String> getAllConversationIds() {
        Set<String> keys = redisTemplate.keys(KEY_PREFIX + "*");
        if (keys != null) {
            // 移除前缀,只返回conversationId
            List<String> conversationIds = new ArrayList<>();
            for (String key : keys) {
                conversationIds.add(key.substring(KEY_PREFIX.length()));
            }
            return conversationIds;
        }
        return new ArrayList<>();
    }

    /**
     * 从 Redis 获取消息列表
     */
    @SuppressWarnings("unchecked")
    private List<Message> getFromRedis(String conversationId) {
        String key = getRedisKey(conversationId);
        Object value = redisTemplate.opsForValue().get(key);
        // 处理空值或类型不匹配的情况
        if (value == null) {
            return new ArrayList<>();
        }
        if (!(value instanceof List)) {
            log.error("对话 [{}] 的消息存储格式不正确,预期为 List,实际为: {}",
                    conversationId, value.getClass().getName());
            return new ArrayList<>();
        }
        List<String> serializedMessages = new ArrayList<>();
        for (Object item : (List<?>) value) {
            if (item instanceof String) {
                serializedMessages.add((String) item);
            } else {
                log.warn("对话 [{}] 中发现非字符串类型的消息,跳过: {}",
                        conversationId, item.getClass().getName());
            }
        }
        List<Message> messages = new ArrayList<>(serializedMessages.size());
        for (String serialized : serializedMessages) {
            try {
                Message message = MessageSerializer.deserialize(serialized);
                messages.add(message);
            } catch (Exception e) {
                log.error("反序列化消息失败,跳过该消息: {}", serialized, e);
            }
        }
        return messages;
    }

    /**
     * 将消息列表存入 Redis
     */
    private void setToRedis(String conversationId, List<Message> messages) {
        String key = getRedisKey(conversationId);
        List<String> serializedMessages = new ArrayList<>(messages.size());
        for (Message message : messages) {
            try {
                String serialized = MessageSerializer.serialize(message);
                serializedMessages.add(serialized);
            } catch (Exception e) {
                log.error("序列化消息失败,跳过该消息: {}", message, e);
            }
        }
        redisTemplate.opsForValue().set(key, serializedMessages, 1L, TimeUnit.DAYS);
    }

    /**
     * 生成带前缀的 Redis 键
     */
    private String getRedisKey(String conversationId) {
        return KEY_PREFIX + conversationId;
    }
}



6)配置到自己的APP里面

this.chatClient = ChatClient.builder(dashscopeChatModel)
        .defaultSystem(SYSTEM_PROMPT)
        .defaultAdvisors(
                new MessageChatMemoryAdvisor(redisChatMemory),
                //自定义日志拦截器,可按需开启
                new MyLoggerAdvisor(),
                //权限校验
                new AuthAdvisor(),
                //违禁词校验
                new BannedWordsAdvisor()
        )
        .build();
代码测试
  1. 先与AI对话

在这里插入图片描述

  1. 重启项目重新对话,询问记录的消息

在这里插入图片描述

成功!!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值