若要实现自定义MySQL持久化ChatMemory,可以参考这篇文章
Spring AI 自定义数据库持久化的ChatMemory
这里我使用的是spring ai 1.0.0-M6.1 现在spring ai已经发布了正式版,其实现与现在有了一定的区别.
自定义实现ChatMemory
Spring AI的对话记忆实现非常巧妙,解耦了"储存"和"记忆算法",
-
存储:ChatMemory:我们可以单独修改ChatMemory存储来改变对话记忆的保存位置,而无需修改保存对话记忆的流程。
-
记忆算法:ChatMemory Advisor,advisor可以理解为拦截器,在调用大模型时的前或后执行一些操作
- MessageChatMemoryAdvisor: 从记忆中(ChatMemory)检索历史对话,并将其作为消息集合添加到提示词中。常用。能更好的保持上下文连贯性。
- PromptChatMemoryAdvisor: 从记忆中检索历史对话,并将其添加到提示词的系统文本中。可以理解为没有结构性的纯文本。
- VectorStoreChatMemoryAdvisor: 可以用向量数据库来存储检索历史对话。
我们可以单独修改ChatMemory储存来改变对话记忆的保存位置,而无需修改保存对话记忆的流程.
虽然官方文档没有给我们自定义ChatMemory实现的示例,但是我们可以直接去阅读默认实现类 InMemoryChatMemory 的源码

其本质是实现了ChatMemory的增删查接口

所以我们想实现自己的持久化,修改对应的储存实现就行了.
参考 InMemoryChatMemory 的源码,其实就是通过 ConcurrentHashMap 来维护对话信息,key 是对话 id(相当于房间号),value 是该对话 id 对应的消息列表。
自定义Redis持久化ChatMemory
由于List<Message>中Message是一个接口,虽然需要实现的接口不多,但是实现起来还是有一定复杂度的,一个最主要的问题是 消息和文本的转换。我们在保存消息时,要将消息从 Message 对象转为文件内的文本;读取消息时,要将文件内的文本转换为 Message 对象。也就是对象的序列化和反序列化。
我们本能地会想到通过 JSON 进行序列化,但实际操作中,我们发现这并不容易。原因是:
- 要持久化的 Message 是一个接口,有很多种不同的子类实现(比如 UserMessage、SystemMessage 等)
- 每种子类所拥有的字段都不一样,结构不统一
- 子类没有无参构造函数,而且没有实现 Serializable 序列化接口
在这里我们使用Kryo的序列化库来实现序列化
1)引入redis依赖
这里使用的是Spring 3.4.4 ,Java 21
<!-- Redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
2)修改配置
# redis 配置
data:
redis:
port: 6379
host: localhost
database: 0
3)配置redis的bean注入
@Configuration
public class RedisTemplateConfig {
@Bean
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<String, Object> template = new RedisTemplate<>();
template.setConnectionFactory(connectionFactory);
template.setKeySerializer(RedisSerializer.string());
return template;
}
}
4)实现序列化
- 引入依赖
<!-- 自定义持久化的序列化库-->
<dependency>
<groupId>com.esotericsoftware</groupId>
<artifactId>kryo</artifactId>
<version>5.6.2</version>
</dependency>
- 创建序列化实现工具类
@Component
public class MessageSerializer {
// ⚠️ 静态 Kryo 实例(线程不安全,建议改用局部实例)
private static final Kryo kryo = new Kryo();
static {
kryo.setRegistrationRequired(false);
// 设置实例化策略(需确保兼容所有 Message 实现类)
kryo.setInstantiatorStrategy(new StdInstantiatorStrategy());
}
/**
* 使用 Kryo 将 Message 序列化为 Base64 字符串
*/
public static String serialize(Message message) {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
Output output = new Output(baos)) {
kryo.writeClassAndObject(output, message); // ⚠️ 依赖动态注册和实例化策略
output.flush();
return Base64.getEncoder().encodeToString(baos.toByteArray());
} catch (IOException e) {
throw new RuntimeException("序列化失败", e);
}
}
/**
* 使用 Kryo 将 Base64 字符串反序列化为 Message 对象
*/
public static Message deserialize(String base64) {
try (ByteArrayInputStream bais = new ByteArrayInputStream(Base64.getDecoder().decode(base64));
Input input = new Input(bais)) {
return (Message) kryo.readClassAndObject(input); // ⚠️ 依赖动态注册和实例化策略
} catch (IOException e) {
throw new RuntimeException("反序列化失败", e);
}
}
}
5)实现自定义ChatMemory
/**
* 基于 Redis 的对话记忆存储实现
*/
@Service
@Slf4j
public class RedisChatMemory implements ChatMemory {
// Redis 键前缀,避免键冲突
private static final String KEY_PREFIX = "chat:memory:";
private final RedisTemplate<String, Object> redisTemplate;
@Resource
private MySQLChatMemoryStore mySQLChatMemoryStore;
public RedisChatMemory(RedisTemplate<String, Object> redisTemplate) {
this.redisTemplate = redisTemplate;
}
/**
* 添加单条消息到对话历史
*
* @param conversationId 对话 ID
* @param message 消息对象
*/
@Override
public void add(String conversationId, Message message) {
add(conversationId, List.of(message));
}
/**
* 添加多条消息到对话历史
*
* @param conversationId 对话 ID
* @param messages 消息列表
*/
@Override
public void add(String conversationId, List<Message> messages) {
// 获取现有消息列表
List<Message> existingMessages = getFromRedis(conversationId);
// 合并消息
existingMessages.addAll(messages);
// 保存更新后的消息列表
setToRedis(conversationId, existingMessages);
// 异步存储到MySQL数据库(若自己实现了可以使用)
// mySQLChatMemoryStore.storeMessages(conversationId, messages);
// 检查消息数量,如果超过20条则删除多余部分,只保留最新的20条
if (existingMessages.size() > 20) {
trimConversation(conversationId);
}
log.debug("已向对话 [{}] 添加 {} 条消息,当前总消息数: {}",
conversationId, messages.size(), Math.min(existingMessages.size(), 20));
}
/**
* 获取对话的最近 N 条消息
*
* @param conversationId 对话 ID
* @param lastN 最近的消息数量
* @return 消息列表
*/
@Override
public List<Message> get(String conversationId, int lastN) {
List<Message> allMessages = getFromRedis(conversationId);
int skip = Math.max(0, allMessages.size() - lastN);
return allMessages.stream()
.skip(skip)
.toList();
}
/**
* 清空对话历史
*
* @param conversationId 对话 ID
*/
@Override
public void clear(String conversationId) {
String key = getRedisKey(conversationId);
redisTemplate.delete(key);
log.debug("已清空对话 [{}] 的历史消息", conversationId);
}
/**
* 清理对话历史,只保留最新的20条消息
*
* @param conversationId 对话 ID
*/
public void trimConversation(String conversationId) {
List<Message> allMessages = getFromRedis(conversationId);
if (allMessages.size() > 20) {
// 只保留最新的20条消息
List<Message> recentMessages = allMessages.subList(allMessages.size() - 20, allMessages.size());
setToRedis(conversationId, recentMessages);
log.debug("已清理对话 [{}] 的历史消息,从 {} 条减少到 20 条", conversationId, allMessages.size());
}
}
/**
* 获取所有对话ID
*
* @return 对话ID集合
*/
public List<String> getAllConversationIds() {
Set<String> keys = redisTemplate.keys(KEY_PREFIX + "*");
if (keys != null) {
// 移除前缀,只返回conversationId
List<String> conversationIds = new ArrayList<>();
for (String key : keys) {
conversationIds.add(key.substring(KEY_PREFIX.length()));
}
return conversationIds;
}
return new ArrayList<>();
}
/**
* 从 Redis 获取消息列表
*/
@SuppressWarnings("unchecked")
private List<Message> getFromRedis(String conversationId) {
String key = getRedisKey(conversationId);
Object value = redisTemplate.opsForValue().get(key);
// 处理空值或类型不匹配的情况
if (value == null) {
return new ArrayList<>();
}
if (!(value instanceof List)) {
log.error("对话 [{}] 的消息存储格式不正确,预期为 List,实际为: {}",
conversationId, value.getClass().getName());
return new ArrayList<>();
}
List<String> serializedMessages = new ArrayList<>();
for (Object item : (List<?>) value) {
if (item instanceof String) {
serializedMessages.add((String) item);
} else {
log.warn("对话 [{}] 中发现非字符串类型的消息,跳过: {}",
conversationId, item.getClass().getName());
}
}
List<Message> messages = new ArrayList<>(serializedMessages.size());
for (String serialized : serializedMessages) {
try {
Message message = MessageSerializer.deserialize(serialized);
messages.add(message);
} catch (Exception e) {
log.error("反序列化消息失败,跳过该消息: {}", serialized, e);
}
}
return messages;
}
/**
* 将消息列表存入 Redis
*/
private void setToRedis(String conversationId, List<Message> messages) {
String key = getRedisKey(conversationId);
List<String> serializedMessages = new ArrayList<>(messages.size());
for (Message message : messages) {
try {
String serialized = MessageSerializer.serialize(message);
serializedMessages.add(serialized);
} catch (Exception e) {
log.error("序列化消息失败,跳过该消息: {}", message, e);
}
}
redisTemplate.opsForValue().set(key, serializedMessages, 1L, TimeUnit.DAYS);
}
/**
* 生成带前缀的 Redis 键
*/
private String getRedisKey(String conversationId) {
return KEY_PREFIX + conversationId;
}
}
6)配置到自己的APP里面
this.chatClient = ChatClient.builder(dashscopeChatModel)
.defaultSystem(SYSTEM_PROMPT)
.defaultAdvisors(
new MessageChatMemoryAdvisor(redisChatMemory),
//自定义日志拦截器,可按需开启
new MyLoggerAdvisor(),
//权限校验
new AuthAdvisor(),
//违禁词校验
new BannedWordsAdvisor()
)
.build();
代码测试
- 先与AI对话

- 重启项目重新对话,询问记录的消息

成功!!!
4727

被折叠的 条评论
为什么被折叠?



