一、环境准备
<!-- Spring AI (Alibaba DashScope starter) -->
<dependency>
    <groupId>com.alibaba.cloud.ai</groupId>
    <artifactId>spring-ai-alibaba-starter</artifactId>
    <version>1.0.0-M6.1</version>
</dependency>
<!--
  PostgreSQL JDBC driver.
  42.3.1 is affected by published security advisories (e.g. CVE-2022-21724, CVE-2024-1597);
  a patched 42.7.x release is used instead.
-->
<dependency>
    <groupId>org.postgresql</groupId>
    <artifactId>postgresql</artifactId>
    <version>42.7.4</version>
</dependency>
<!-- PGvector store starter. NOTE(review): not referenced by the ChatMemory code in this article. -->
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
    <version>1.0.0-M6</version>
</dependency>
<!-- JdbcTemplate support; the version is expected to come from Spring Boot dependency management. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
上一篇文章实现了基于内存的会话记忆功能,本文在此基础上改为基于数据库存储:
使用 JdbcTemplate + PostgreSQL 自定义实现 ChatMemory 接口。
二、扩展实现
package org.springframework.ai.chat.memory;
import java.util.List;
import org.springframework.ai.chat.messages.Message;
/**
 * Storage contract for conversation history, keyed by a conversation id.
 * Reproduced here for reference; the custom PostgreSQL memory below implements it.
 */
public interface ChatMemory {

    /**
     * Appends a single message to a conversation by wrapping it in a one-element list
     * and delegating to {@link #add(String, List)}.
     */
    default void add(String conversationId, Message message) {
        add(conversationId, List.of(message));
    }

    /** Appends a batch of messages to a conversation. */
    void add(String conversationId, List<Message> messages);

    /** Retrieves messages of a conversation; {@code lastN} bounds how many are returned. */
    List<Message> get(String conversationId, int lastN);

    /** Discards the stored history of a conversation. */
    void clear(String conversationId);
}
- 表结构准备
CREATE TABLE chat_messages (
    id              BIGSERIAL    PRIMARY KEY,
    conversation_id VARCHAR(255) NOT NULL,
    message_type    VARCHAR(50)  NOT NULL,
    content         TEXT         NOT NULL,
    -- The application sets this explicitly on insert; the default only covers manual inserts.
    created_at      TIMESTAMP    NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Composite index for "latest messages of one conversation" lookups
-- (filter by conversation_id, order by created_at descending, limited to N rows):
-- PostgreSQL can scan this index backwards and stop after N rows instead of sorting
-- every message of the conversation. The trailing id column also covers an id tie-breaker.
CREATE INDEX idx_chat_messages_conversation ON chat_messages (conversation_id, created_at, id);
- ChatDao接口定义
package org.spring.springaiprojet.dao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import java.util.List;
/**
 * Data-access operations for persisted chat messages.
 */
public interface ChatDao {

    /**
     * Persists a batch of chat message rows.
     *
     * @param messages the rows to store
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * Looks up the most recent messages of one conversation.
     *
     * @param conversationId the conversation to read
     * @param lastN          the maximum number of rows to return
     * @return the matching message rows
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);

    /**
     * Removes every stored message of a conversation.
     *
     * @param conversationId the conversation whose messages are deleted
     */
    void deleteByConversationId(String conversationId);
}
- ChatDaoImpl 接口实现
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;
import java.time.LocalDateTime;
import java.util.List;
/**
 * {@link ChatDao} implementation that reads and writes the {@code chat_messages} table
 * through Spring's {@link JdbcTemplate}.
 */
@Repository
public class ChatDaoImpl implements ChatDao {

    private static final String INSERT_SQL =
            "insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)";

    /*
     * Newest first. created_at alone is not a strict order: messages written in the same instant
     * share a timestamp, so id (assigned by the BIGSERIAL sequence at insert time) breaks ties.
     */
    private static final String SELECT_LAST_N_SQL =
            "select id, conversation_id, message_type, content, created_at from chat_messages"
                    + " where conversation_id = ? order by created_at desc, id desc limit ?";

    private static final String DELETE_SQL = "delete from chat_messages where conversation_id = ?";

    @Autowired
    private JdbcTemplate jdbcTemplate;

    /**
     * Inserts all given rows in a single JDBC batch.
     *
     * @param messages rows to insert; {@code null} or empty is a no-op (no database round trip)
     */
    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        jdbcTemplate.batchUpdate(INSERT_SQL, messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    /**
     * Loads the most recent rows of a conversation, newest first.
     *
     * @param conversationId conversation to read
     * @param lastN          maximum number of rows; values {@code <= 0} yield an empty list
     *                       (PostgreSQL rejects a negative LIMIT)
     * @return up to {@code lastN} rows ordered from newest to oldest
     */
    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        if (lastN <= 0) {
            return List.of();
        }
        return jdbcTemplate.query(
                SELECT_LAST_N_SQL,
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setId(rs.getLong("id"));
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                lastN);
    }

    /**
     * Deletes every row belonging to a conversation.
     *
     * @param conversationId conversation to delete
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update(DELETE_SQL, conversationId);
    }
}
- ChatMessageEntity 实体类定义
package org.spring.springaiprojet.entity;
import java.time.LocalDateTime;
/**
 * Row model for the {@code chat_messages} table: one persisted chat message.
 */
public class ChatMessageEntity {

    /** Primary key of the row. */
    private Long id;
    /** Conversation the message belongs to. */
    private String conversationId;
    /** Stored message-type tag. */
    private String messageType;
    /** Message text. */
    private String content;
    /** Timestamp recorded for the message. */
    private LocalDateTime createdAt;

    // ---- accessors ----

    public Long getId() {
        return this.id;
    }

    public String getConversationId() {
        return this.conversationId;
    }

    public String getMessageType() {
        return this.messageType;
    }

    public String getContent() {
        return this.content;
    }

    public LocalDateTime getCreatedAt() {
        return this.createdAt;
    }

    // ---- mutators ----

    public void setId(Long id) {
        this.id = id;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }
}
- PgSQLChatMemory 实现(自定义 ChatMemory)
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link ChatMemory} that persists conversation history in PostgreSQL through {@link ChatDao}.
 *
 * <p>Only {@link UserMessage} and {@link AssistantMessage} instances with non-null text are stored.
 * They are the only kinds {@link #get(String, int)} can rebuild, and the {@code message_type} and
 * {@code content} columns of the {@code chat_messages} table are declared NOT NULL.
 */
@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;

    /**
     * Appends messages to a conversation, stamping each with the current time.
     *
     * <p>Messages of other types (for example system or tool messages), {@code null} entries and
     * messages without text are skipped. Mapping them would produce a row with a null type or
     * content, which the NOT NULL columns reject and which would fail the whole batch.
     *
     * @param conversationId conversation key the messages are stored under
     * @param messages       messages to store; {@code null} or empty is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entities = messages.stream()
                .filter(PgSQLChatMemory::isPersistable)
                .map(msg -> toEntity(conversationId, msg))
                .collect(Collectors.toList());
        if (entities.isEmpty()) {
            return;
        }
        chatDao.insertMessages(entities);
    }

    /**
     * Returns the most recent {@code lastN} messages of a conversation, oldest first.
     *
     * @param conversationId conversation to read
     * @param lastN          maximum number of messages; values {@code <= 0} yield an empty list
     * @return the messages in chronological order, or an empty list when there are none
     * @throws IllegalArgumentException if a stored row has a type that is neither user nor assistant
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        if (lastN <= 0) {
            return Collections.emptyList();
        }
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // The DAO returns rows newest first. Walk them backwards so the history reads oldest to newest,
        // without mutating the list the DAO handed back.
        List<Message> history = new ArrayList<>(entities.size());
        for (int i = entities.size() - 1; i >= 0; i--) {
            history.add(toMessage(entities.get(i)));
        }
        return history;
    }

    /**
     * Deletes the stored history of a conversation; a {@code null} id is ignored.
     *
     * @param conversationId conversation to clear
     */
    @Override
    public void clear(String conversationId) {
        if (conversationId == null) {
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

    /** Tells whether a message can be stored and later rebuilt by {@link #toMessage}. */
    private static boolean isPersistable(Message msg) {
        return (msg instanceof UserMessage || msg instanceof AssistantMessage) && msg.getText() != null;
    }

    /** Maps a user or assistant message to a new row stamped with the current time. */
    private static ChatMessageEntity toEntity(String conversationId, Message msg) {
        ChatMessageEntity entity = new ChatMessageEntity();
        entity.setConversationId(conversationId);
        entity.setContent(msg.getText());
        entity.setMessageType(msg instanceof UserMessage
                ? MessageEnum.USER.getValue()
                : MessageEnum.ASSISTANT.getValue());
        entity.setCreatedAt(LocalDateTime.now());
        return entity;
    }

    /** Rebuilds a Spring AI message from a stored row, using the same type values that add writes. */
    private static Message toMessage(ChatMessageEntity entity) {
        String type = entity.getMessageType();
        if (MessageEnum.USER.getValue().equals(type)) {
            return new UserMessage(entity.getContent());
        }
        if (MessageEnum.ASSISTANT.getValue().equals(type)) {
            return new AssistantMessage(entity.getContent());
        }
        throw new IllegalArgumentException("未知的消息类型!");
    }
}
- AiConfig实现
/**
 * PostgreSQL-backed conversation memory, wired into the chat client bean below.
 */
@Autowired
private PgSQLChatMemory pgSQLChatMemory;

/**
 * Builds the application's {@link ChatClient} with a {@link PromptChatMemoryAdvisor} over the
 * PostgreSQL-backed memory registered as a default advisor.
 *
 * <p>Optional extras: {@code .defaultSystem("...")} sets a default system prompt, and
 * {@code new SimpleLoggerAdvisor()} can be appended to the default advisors (presumably for
 * request/response logging; confirm against the Spring AI docs).
 */
@Bean
public ChatClient chatClient(ChatClient.Builder builder) {
    PromptChatMemoryAdvisor memoryAdvisor = new PromptChatMemoryAdvisor(pgSQLChatMemory);
    ChatClient.Builder configured = builder.defaultAdvisors(memoryAdvisor);
    return configured.build();
}
- 控制器实现
/**
 * Plain (blocking) chat call that uses the memory advisor's fallback conversation.
 * Kept so existing Java callers of the one-argument form keep compiling.
 *
 * @param question the user's question
 * @return the model's reply text
 */
public String chat(String question) {
    return chat(question, null);
}

/**
 * Plain (blocking) chat endpoint with per-conversation memory.
 *
 * <p>NOTE(review): without an explicit conversation id, Spring AI's chat-memory advisors fall back to
 * one fixed default conversation id. Every client would then read and write the same shared
 * history. Confirm for the Spring AI version in use. Clients that pass {@code conversationId} get
 * an isolated history.
 *
 * @param question       the user's question
 * @param conversationId optional request parameter; when absent or blank, the previous behaviour
 *                       (advisor's default conversation) is kept
 * @return the model's reply text
 */
@RequestMapping("/qwen/chat/api")
public String chat(String question, String conversationId) {
    var request = chatClient.prompt().user(question);
    if (conversationId != null && !conversationId.isBlank()) {
        request = request.advisors(spec -> spec.param(
                org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY,
                conversationId));
    }
    return request.call().content();
}
三、测试调用
- 第一次调用如下接口
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用
可以看到问答内容已经写入数据表中。

- 第二次调用:基于上一轮的会话内容继续提问,问答同样写入表中
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表,可以看到本轮问答也已按会话保存。

- 第三次提问,验证会话记忆
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库,可以看到问答已经保存进来,模型基于数据库中的会话内容实现了上下文对话。

4029

被折叠的 条评论
为什么被折叠?



