SpringAi基于PgSQL数据库存储扩展ChatMemory

一、环境准备

      SpringAI入门学习

        <!-- SpringAI-->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter</artifactId>
            <version>1.0.0-M6.1</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.3.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
            <version>1.0.0-M6</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>

     上文中实现了基于内存的会话保持功能,现在要基于数据库扩展实现

     JdbcTemplate+PgSQL数据库实现扩展ChatMemory实现。

二、扩展实现

package org.springframework.ai.chat.memory;

import java.util.List;
import org.springframework.ai.chat.messages.Message;

public interface ChatMemory {
    default void add(String conversationId, Message message) {
        this.add(conversationId, List.of(message));
    }

    void add(String conversationId, List<Message> messages);

    List<Message> get(String conversationId, int lastN);

    void clear(String conversationId);
}
  • 表结构准备
CREATE TABLE chat_messages (
                               id BIGSERIAL PRIMARY KEY,
                               conversation_id VARCHAR(255) NOT NULL,
                               message_type VARCHAR(50) NOT NULL,
                               content TEXT NOT NULL,
                               created_at TIMESTAMP NOT NULL
);
-- 创建索引
CREATE INDEX idx_conversation_id ON chat_messages (conversation_id);
  • ChatDao接口定义
package org.spring.springaiprojet.dao;

import org.spring.springaiprojet.entity.ChatMessageEntity;

import java.util.List;

public interface ChatDao {
    /**
     * 保存表
     * @param messages
     */
    void insertMessages(List<ChatMessageEntity> messages);

    /**
     * 查询最近的N条消息
     * @param conversationId
     * @param lastN
     * @return
     */
    List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN);


    /**
     * 删除会话
     * @param conversationId
     */
    void deleteByConversationId(String conversationId);
}
  • ChatImpl接口实现
import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;

import java.time.LocalDateTime;
import java.util.List;

@Repository
public class ChatDaoImpl implements ChatDao {

    @Autowired
    private JdbcTemplate jdbcTemplate;

    @Override
    public void insertMessages(List<ChatMessageEntity> messages) {
         jdbcTemplate.batchUpdate("insert into chat_messages (conversation_id, message_type, content, created_at) values (?, ?, ?, ?)", messages, messages.size(), (ps, message) -> {
            ps.setString(1, message.getConversationId());
            ps.setString(2, message.getMessageType());
            ps.setString(3, message.getContent());
            ps.setObject(4, message.getCreatedAt());
        });
    }

    @Override
    public List<ChatMessageEntity> findLastNMessages(String conversationId, int lastN) {
        return jdbcTemplate.query(
                "select * from chat_messages where conversation_id = ? order by created_at desc limit ?",
                (rs, rowNum) -> {
                    ChatMessageEntity message = new ChatMessageEntity();
                    message.setConversationId(rs.getString("conversation_id"));
                    message.setMessageType(rs.getString("message_type"));
                    message.setContent(rs.getString("content"));
                    message.setCreatedAt(rs.getObject("created_at", LocalDateTime.class));
                    return message;
                },
                conversationId,
                lastN
        );
    }

    @Override
    public void deleteByConversationId(String conversationId) {
        jdbcTemplate.update("delete from chat_messages where conversation_id = ?", conversationId);
    }
}
  • ChatMessageEntity
package org.spring.springaiprojet.entity;

import java.time.LocalDateTime;

public class ChatMessageEntity {
    private Long id;
    private String conversationId;
    private String messageType;
    private String content;
    private LocalDateTime createdAt;

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getConversationId() {
        return conversationId;
    }

    public void setConversationId(String conversationId) {
        this.conversationId = conversationId;
    }

    public String getMessageType() {
        return messageType;
    }

    public void setMessageType(String messageType) {
        this.messageType = messageType;
    }

    public String getContent() {
        return content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public LocalDateTime getCreatedAt() {
        return createdAt;
    }

    public void setCreatedAt(LocalDateTime createdAt) {
        this.createdAt = createdAt;
    }
}
  • PgSQLChatMemory
package org.spring.springaiprojet.config.chat;

import org.spring.springaiprojet.dao.ChatDao;
import org.spring.springaiprojet.entity.ChatMessageEntity;
import org.spring.springaiprojet.entity.MessageEnum;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@Component
public class PgSQLChatMemory implements ChatMemory {

    @Autowired
    private ChatDao chatDao;
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entities = messages.stream()
                .map(msg -> {
                    ChatMessageEntity entity = new ChatMessageEntity();
                    entity.setConversationId(conversationId);
                    entity.setContent(msg.getText());

                    if (msg instanceof UserMessage) {
                        entity.setMessageType(MessageEnum.USER.getValue());
                    } else if (msg instanceof AssistantMessage) {
                        entity.setMessageType(MessageEnum.ASSISTANT.getValue());
                    }
                    entity.setCreatedAt(LocalDateTime.now());
                    return entity;
                })
                .collect(Collectors.toList());
       chatDao.insertMessages(entities);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<ChatMessageEntity> entities = chatDao.findLastNMessages(conversationId, lastN);
        if (entities == null || entities.isEmpty()) {
            return Collections.emptyList();
        }
        // 倒序
        Collections.reverse(entities);
        return entities.stream()
                .map(entity -> {
                    switch (entity.getMessageType()) {
                        case "user":
                            return new UserMessage(entity.getContent());
                        case "assistant":
                            return new AssistantMessage(entity.getContent());
                        default:
                            throw new IllegalArgumentException("未知的消息类型!");
                    }
                })
                .collect(Collectors.toList());
    }

    @Override
    public void clear(String conversationId) {
        if (conversationId == null){
            return;
        }
        chatDao.deleteByConversationId(conversationId);
    }

}
  • AiConfig实现
    /**
     * 基于PgSQL实现会话记忆
     */
    @Autowired
    private PgSQLChatMemory pgSQLChatMemory;

    @Bean
    public ChatClient chatClient(ChatClient.Builder builder) {
        // defaultSystem,默认系统角色,带有对话身份
        return builder
//                .defaultSystem("请以通俗开发者角度介绍")
                // 增加会话记忆,基于
                .defaultAdvisors(new PromptChatMemoryAdvisor(pgSQLChatMemory)).build();
//                .defaultAdvisors(new SimpleLoggerAdvisor()).build();
    }
  • 控制器实现
    /**
     * 普通对话模式
     * @param question
     * @return
     */
    @RequestMapping("/qwen/chat/api")
    public String chat(String question) {
        return chatClient.prompt().user(question).call().content();
    }

三、测试调用

  • 第一次调用如下接口
GET http://localhost:8088/boot/ai/qwen/chat/api?question=SpringBoot框架各个模块作用

已经写入数据表中了

  • 第二次基于上文会话内容,写入表中
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度较大,评估一下

再次查看数据库表,发现已经基于会话保存回答问题了

  • 第三次提问,会话记忆
GET http://localhost:8088/boot/ai/qwen/chat/api?question=基于上述回答内容,SpringBoot哪个模块学习难度相对较小,比较容易上手,评估一下

再次查看数据库,以及基本保存进来,基于数据库内容实现上下文会话回答了。

Pgsql中,存储过程有其特定的使用方法和示例。 ### 判断存储过程是否存在 可以使用以下SQL语句判断存储过程是否存在: ```sql SELECT EXISTS( SELECT * FROM pg_catalog.pg_proc JOIN pg_namespace ON pg_catalog.pg_proc.pronamespace = pg_namespace.oid WHERE proname = '存储过程名称' AND pg_namespace.nspname = '存储过程所在的层 即schema' ); ``` 这里通过查询`pg_catalog.pg_proc`和`pg_namespace`表,根据存储过程名称和所在的`schema`来判断存储过程是否存在[^1]。 ### 用PL/pgSQL存储过程 PL/pgSQL是PostgreSQL的一种过程语言,用于编写存储过程。其存储过程结构包含变量类型、连接字符、控制结构(如if条件、多种循环结构、异常捕获等)。例如常见的控制结构有if条件(五种形式)、循环(如LOOP、EXIT、CONTINUE、WHILE、FOR(整数变种))等 [^2]。 ### 存储过程示例 以下是一个简单的转账存储过程示例: ```sql -- 创建accounts表 CREATE TABLE accounts ( id INT GENERATED BY DEFAULT AS IDENTITY, name VARCHAR(100) NOT NULL, balance DEC(15,2) NOT NULL, PRIMARY KEY(id) ); -- 插入示例数据 INSERT INTO accounts(name,balance) VALUES('Bob',10000); INSERT INTO accounts(name,balance) VALUES('Alice',10000); -- 查看accounts表数据 select * from accounts; -- 创建或替换存储过程transfer CREATE OR REPLACE PROCEDURE transfer(INT, INT, DEC) LANGUAGE plpgsql AS $$ BEGIN -- 从发送者账户减去相应金额 UPDATE accounts SET balance = balance - $3 WHERE id = $1; -- 向接收者账户添加相应金额 UPDATE accounts SET balance = balance + $3 WHERE id = $2; COMMIT; END; $$; -- 调用存储过程 call transfer(1,2,1000); -- 再次查看accounts表数据 select * from accounts; ``` 这个示例中,首先创建了`accounts`表并插入了示例数据,然后创建了一个名为`transfer`的存储过程,用于实现账户之间的转账功能,最后调用该存储过程并查看转账后的账户数据 [^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大道之简

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值