Applying Design Patterns in Spring AI: A Complete Project Implementation
🎯 I. Project Overview
This project simulates the core principles of Spring AI, combining several design patterns to build a complete AI service integration framework. We will implement the following core features:
- Unified AI client interface - supports multiple AI models (OpenAI, Azure, local models)
- Intelligent routing - automatically selects an adapter based on the model type
- Extensible plugin architecture - supports custom model extensions
- Complete configuration management - Spring Boot auto-configuration
- Monitoring and metrics collection - integrated with Micrometer
- A full example application - demonstrating how to use the framework
📁 II. Complete Project Structure
spring-ai-simulation/
├── pom.xml # Maven dependency configuration
├── src/
│ ├── main/
│ │ ├── java/
│ │ │ └── com/
│ │ │ └── example/
│ │ │ └── ai/
│ │ │ ├── SpringAiSimulationApplication.java
│ │ │ ├── annotation/
│ │ │ │ └── EnableAiClient.java
│ │ │ ├── config/
│ │ │ │ ├── AiAutoConfiguration.java
│ │ │ │ ├── AiProperties.java
│ │ │ │ ├── AiClientConfig.java
│ │ │ │ └── condition/
│ │ │ │ ├── OnAiModelCondition.java
│ │ │ │ └── OnAiProviderCondition.java
│ │ │ ├── core/
│ │ │ │ ├── AiClient.java
│ │ │ │ ├── ChatClient.java
│ │ │ │ ├── EmbeddingClient.java
│ │ │ │ ├── model/
│ │ │ │ │ ├── ChatRequest.java
│ │ │ │ │ ├── ChatResponse.java
│ │ │ │ │ ├── Message.java
│ │ │ │ │ ├── Choice.java
│ │ │ │ │ ├── Usage.java
│ │ │ │ │ ├── EmbeddingRequest.java
│ │ │ │ │ └── EmbeddingResponse.java
│ │ │ │ ├── template/
│ │ │ │ │ ├── AiClientTemplate.java
│ │ │ │ │ ├── BaseAiClient.java
│ │ │ │ │ ├── AbstractAiClient.java
│ │ │ │ │ └── ChatClientTemplate.java
│ │ │ │ └── exception/
│ │ │ │ ├── AiClientException.java
│ │ │ │ ├── AiServiceException.java
│ │ │ │ ├── RateLimitException.java
│ │ │ │ └── UnauthorizedException.java
│ │ │ ├── pattern/
│ │ │ │ ├── factory/
│ │ │ │ │ ├── AiClientFactory.java
│ │ │ │ │ └── AiAdapterFactory.java
│ │ │ │ ├── strategy/
│ │ │ │ │ ├── AiStrategy.java
│ │ │ │ │ ├── AiProviderStrategy.java
│ │ │ │ │ ├── OpenAiStrategy.java
│ │ │ │ │ ├── AzureAiStrategy.java
│ │ │ │ │ └── LocalAiStrategy.java
│ │ │ │ ├── adapter/
│ │ │ │ │ ├── AiModelAdapter.java
│ │ │ │ │ ├── OpenAiAdapter.java
│ │ │ │ │ ├── AzureAiAdapter.java
│ │ │ │ │ └── LocalAiAdapter.java
│ │ │ │ ├── builder/
│ │ │ │ │ └── ChatRequestBuilder.java
│ │ │ │ ├── chain/
│ │ │ │ │ ├── AiClientInterceptor.java
│ │ │ │ │ └── AiClientInterceptorChain.java
│ │ │ │ └── proxy/
│ │ │ │ └── AiClientProxy.java
│ │ │ ├── provider/
│ │ │ │ ├── OpenAiClient.java
│ │ │ │ ├── AzureAiClient.java
│ │ │ │ └── LocalAiClient.java
│ │ │ ├── service/
│ │ │ │ ├── AiService.java
│ │ │ │ ├── ChatService.java
│ │ │ │ └── EmbeddingService.java
│ │ │ ├── interceptor/
│ │ │ │ ├── LoggingInterceptor.java
│ │ │ │ ├── MetricsInterceptor.java
│ │ │ │ └── RetryInterceptor.java
│ │ │ ├── metrics/
│ │ │ │ └── AiClientMetrics.java
│ │ │ └── router/
│ │ │ └── ModelRouter.java
│ │ └── resources/
│ │ ├── application.yml
│ │ └── application-dev.yml
│ └── test/
│ └── java/
│ └── com/
│ └── example/
│ └── ai/
│ ├── SpringAiSimulationApplicationTests.java
│ └── service/
│ └── AiServiceTest.java
📦 III. Maven Dependency Configuration
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>spring-ai-simulation</artifactId>
<version>1.0.0</version>
<name>Spring AI Simulation</name>
<description>A simulation framework demonstrating Spring AI design patterns</description>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.1.5</version>
<relativePath/>
</parent>
<properties>
<java.version>17</java.version>
<jackson.version>2.15.2</jackson.version>
<micrometer.version>1.11.5</micrometer.version>
</properties>
<dependencies>
<!-- Spring Boot Starters -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Monitoring -->
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-core</artifactId>
</dependency>
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-registry-prometheus</artifactId>
</dependency>
<!-- HTTP Client -->
<dependency>
<groupId>org.apache.httpcomponents.client5</groupId>
<artifactId>httpclient5</artifactId>
</dependency>
<!-- JSON Processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- Utilities -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>32.1.3-jre</version>
</dependency>
<!-- Testing -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
⚙️ 四、核心配置文件
1. application.yml
# Application configuration
server:
port: 8080
servlet:
context-path: /api
spring:
application:
name: spring-ai-simulation
# AI configuration
ai:
enabled: true
default-provider: openai
default-model: gpt-3.5-turbo
enable-metrics: true
# OpenAI configuration
openai:
enabled: true
api-key: ${OPENAI_API_KEY:your-openai-key}
base-url: https://api.openai.com/v1
timeout: 30s
max-retries: 3
# Azure AI configuration
azure:
enabled: false
endpoint: https://your-resource.openai.azure.com
api-key: ${AZURE_API_KEY:your-azure-key}
deployment-name: gpt-35-turbo
api-version: 2023-12-01-preview
# Local model configuration
local:
enabled: false
model-path: /path/to/local/model
device: cpu
# HTTP client configuration
http-client:
max-connections: 50
connection-timeout: 10s
read-timeout: 30s
keep-alive: 5m
# Monitoring configuration
monitoring:
enabled: true
metrics-prefix: ai.client
slow-query-threshold: 1000ms
enable-tracing: true
# Logging configuration
logging:
level:
com.example.ai: DEBUG
pattern:
console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
# Actuator configuration
management:
endpoints:
web:
exposure:
include: health,info,metrics,prometheus
  # Spring Boot 3.x relocated management.metrics.export.prometheus.* to:
  prometheus:
    metrics:
      export:
        enabled: true
endpoint:
health:
show-details: always
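The `ai:` block above is bound to the `AiProperties` class that the factory below depends on. The article never shows that class, so here is a minimal, hypothetical sketch of its shape; the field names are inferred from the YAML keys, and the real class would carry `@ConfigurationProperties(prefix = "ai")` and be registered by the auto-configuration:

```java
// Hypothetical sketch only: in the real project this would be annotated with
// @ConfigurationProperties(prefix = "ai") and populated by Spring Boot binding.
public class AiProperties {
    private boolean enabled = true;
    private String defaultProvider = "openai";
    private String defaultModel = "gpt-3.5-turbo";
    private final OpenAi openai = new OpenAi();

    public boolean isEnabled() { return enabled; }
    public void setEnabled(boolean enabled) { this.enabled = enabled; }
    public String getDefaultProvider() { return defaultProvider; }
    public void setDefaultProvider(String defaultProvider) { this.defaultProvider = defaultProvider; }
    public String getDefaultModel() { return defaultModel; }
    public void setDefaultModel(String defaultModel) { this.defaultModel = defaultModel; }
    public OpenAi getOpenai() { return openai; }

    // Nested group for the ai.openai.* keys; azure/local would follow the same shape.
    public static class OpenAi {
        private boolean enabled;
        private String apiKey;
        private String baseUrl = "https://api.openai.com/v1";

        public boolean isEnabled() { return enabled; }
        public void setEnabled(boolean enabled) { this.enabled = enabled; }
        public String getApiKey() { return apiKey; }
        public void setApiKey(String apiKey) { this.apiKey = apiKey; }
        public String getBaseUrl() { return baseUrl; }
        public void setBaseUrl(String baseUrl) { this.baseUrl = baseUrl; }
    }
}
```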
2. Application Bootstrap Class
package com.example.ai;
import com.example.ai.annotation.EnableAiClient;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
@EnableAiClient
public class SpringAiSimulationApplication {
public static void main(String[] args) {
SpringApplication.run(SpringAiSimulationApplication.class, args);
}
}
🔧 V. Core Design Pattern Implementations
1. Strategy Pattern - AI Provider Strategies
// Strategy interface
package com.example.ai.pattern.strategy;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
/**
* AI provider strategy interface
*/
public interface AiProviderStrategy {
String getProviderName();
ChatResponse chat(ChatRequest request);
EmbeddingResponse embed(EmbeddingRequest request);
boolean supports(String model);
void validateApiKey(String apiKey);
}
// OpenAI strategy implementation
package com.example.ai.pattern.strategy;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;
@Slf4j
public class OpenAiStrategy implements AiProviderStrategy {
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;
private final String apiKey;
private final String baseUrl;
public OpenAiStrategy(RestTemplate restTemplate,
ObjectMapper objectMapper,
String apiKey,
String baseUrl) {
this.restTemplate = restTemplate;
this.objectMapper = objectMapper;
this.apiKey = apiKey;
this.baseUrl = baseUrl;
}
@Override
public String getProviderName() {
return "openai";
}
@Override
public ChatResponse chat(ChatRequest request) {
log.debug("OpenAI Strategy: Processing chat request for model: {}", request.getModel());
// Build the request in OpenAI's wire format
var openAiRequest = buildOpenAiChatRequest(request);
// Send the request
var headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setBearerAuth(apiKey);
var entity = new HttpEntity<>(openAiRequest, headers);
var url = String.format("%s/chat/completions", baseUrl);
try {
var response = restTemplate.postForObject(url, entity, String.class);
return parseOpenAiChatResponse(response);
} catch (Exception e) {
throw new RuntimeException("OpenAI API call failed", e);
}
}
@Override
public EmbeddingResponse embed(EmbeddingRequest request) {
    // Embedding support is omitted from this simulation
    throw new UnsupportedOperationException("Embeddings are not implemented yet");
}
@Override
public boolean supports(String model) {
return model != null &&
(model.startsWith("gpt-") ||
model.startsWith("text-embedding-"));
}
@Override
public void validateApiKey(String apiKey) {
if (apiKey == null || apiKey.trim().isEmpty()) {
throw new IllegalArgumentException("OpenAI API key is required");
}
if (!apiKey.startsWith("sk-")) {
log.warn("OpenAI API key may be invalid, should start with 'sk-'");
}
}
private Object buildOpenAiChatRequest(ChatRequest request) {
    // Request conversion logic omitted for brevity (see OpenAiAdapter below)
    return null;
}
private ChatResponse parseOpenAiChatResponse(String response) {
    // Response parsing logic omitted for brevity
    return null;
}
}
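The point of the strategy interface is that each provider declares, via `supports(...)`, which models it can serve, and a caller picks the first matching strategy. The following dependency-free sketch isolates that selection mechanic with two fake strategies standing in for the HTTP-backed `OpenAiStrategy`/`LocalAiStrategy` (class and method names here are illustrative, not from the project):

```java
import java.util.List;

// Minimal sketch of strategy selection: each strategy declares which models it
// supports, and the first match wins.
public class StrategySelectionDemo {
    interface Strategy {
        String name();
        boolean supports(String model);
    }

    // Walk the registered strategies in order and return the first that matches.
    static Strategy select(List<Strategy> strategies, String model) {
        return strategies.stream()
                .filter(s -> s.supports(model))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("No strategy for " + model));
    }

    static final Strategy OPENAI = new Strategy() {
        public String name() { return "openai"; }
        public boolean supports(String m) { return m.startsWith("gpt-") || m.startsWith("text-embedding-"); }
    };
    static final Strategy LOCAL = new Strategy() {
        public String name() { return "local"; }
        public boolean supports(String m) { return m.contains("llama"); }
    };

    public static String providerFor(String model) {
        return select(List.of(OPENAI, LOCAL), model).name();
    }
}
```

Registration order matters: if two strategies claimed the same prefix, the one registered first would win, which is why the real factory keys strategies by provider name instead of probing blindly.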
2. Template Method Pattern - AI Client Template
// Abstract template class
package com.example.ai.core.template;
import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.example.ai.pattern.chain.AiClientInterceptorChain;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
/**
* Abstract AI client template
* Defines the standard processing flow for AI requests
*/
@Slf4j
public abstract class AbstractAiClient implements AiClient {
protected final AiClientInterceptorChain interceptorChain;
protected final Executor executor;
protected AbstractAiClient(AiClientInterceptorChain interceptorChain,
Executor executor) {
this.interceptorChain = interceptorChain;
this.executor = executor;
}
@Override
public ChatResponse chat(ChatRequest request) {
    // Template method: defines the standard processing flow
    long startTime = System.currentTimeMillis();
    try {
        // 1. Pre-processing
        preProcess(request);
        // 2. Run the interceptor chain's pre-handlers; a veto aborts the request
        if (!interceptorChain.applyPreHandle(request)) {
            throw new IllegalStateException("Request rejected by interceptor chain");
        }
        // 3. Execute the actual call (implemented by subclasses)
        ChatResponse response = doChat(request);
        // 4. Run the interceptor chain's post-handlers
        interceptorChain.applyPostHandle(request, response);
        // 5. Post-processing
        postProcess(request, response, startTime);
        return response;
    } catch (Exception e) {
        // 6. Exception handling; also notify interceptors of completion
        handleException(request, e, startTime);
        interceptorChain.applyAfterCompletion(request, e);
        throw e;
    }
}
@Override
public CompletableFuture<ChatResponse> chatAsync(ChatRequest request) {
return CompletableFuture.supplyAsync(() -> chat(request), executor);
}
/**
* Pre-processing (hook method)
*/
protected void preProcess(ChatRequest request) {
log.debug("Pre-processing chat request: {}", request);
validateRequest(request);
}
/**
* The actual chat logic (implemented by subclasses)
*/
protected abstract ChatResponse doChat(ChatRequest request);
/**
* Post-processing (hook method)
*/
protected void postProcess(ChatRequest request, ChatResponse response, long startTime) {
    long duration = System.currentTimeMillis() - startTime;
    if (response.getUsage() != null) {
        log.debug("Chat completed in {} ms. Tokens used: {}",
                duration, response.getUsage().getTotalTokens());
    }
    // Publish a completion event
    publishChatCompletedEvent(request, response, duration);
}
/**
* Exception handling
*/
protected void handleException(ChatRequest request, Exception e, long startTime) {
long duration = System.currentTimeMillis() - startTime;
log.error("Chat request failed after {} ms: {}", duration, e.getMessage(), e);
// Publish a failure event
publishChatFailedEvent(request, e, duration);
}
/**
* Validate the request
*/
protected void validateRequest(ChatRequest request) {
if (request == null) {
throw new IllegalArgumentException("ChatRequest cannot be null");
}
if (request.getMessages() == null || request.getMessages().isEmpty()) {
throw new IllegalArgumentException("Messages cannot be empty");
}
}
// Other hook methods...
protected abstract void publishChatCompletedEvent(ChatRequest request,
ChatResponse response,
long duration);
protected abstract void publishChatFailedEvent(ChatRequest request,
Exception e,
long duration);
}
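The essence of the template above is that the base class freezes the order of the steps (pre-process, do the call, post-process) while subclasses fill in only the call itself. A condensed, runnable sketch of that flow, with a `trace` list added so the step ordering can be observed (the class names here are illustrative):

```java
import java.util.ArrayList;
import java.util.List;

// Condensed sketch of the template method: the base class fixes the order
// pre -> doChat -> post; subclasses supply only doChat.
public class TemplateDemo {
    static abstract class BaseClient {
        final List<String> trace = new ArrayList<>();

        // The template method: final so subclasses cannot reorder the steps.
        public final String chat(String prompt) {
            trace.add("pre");
            String r = doChat(prompt);   // subclass hook
            trace.add("post");
            return r;
        }

        protected abstract String doChat(String prompt);
    }

    static class EchoClient extends BaseClient {
        protected String doChat(String prompt) {
            trace.add("doChat");
            return "echo: " + prompt;
        }
    }
}
```

In `AbstractAiClient` the same idea carries the interceptor chain, timing, and event publication, so every concrete provider inherits them for free.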
3. Factory Pattern - AI Client Factory
// AI client factory
package com.example.ai.pattern.factory;
import com.example.ai.core.AiClient;
import com.example.ai.config.AiProperties;
import com.example.ai.pattern.strategy.AiProviderStrategy;
import com.example.ai.pattern.strategy.OpenAiStrategy;
import com.example.ai.pattern.strategy.AzureAiStrategy;
import com.example.ai.pattern.strategy.LocalAiStrategy;
import com.example.ai.provider.OpenAiClient;
import com.example.ai.provider.AzureAiClient;
import com.example.ai.provider.LocalAiClient;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.client.RestTemplate;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * AI client factory
 * Creates the appropriate AI client for each configured provider
 */
@Slf4j
public class AiClientFactory {
    // ConcurrentHashMap so clients can be created safely from multiple threads
    private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
private final AiProperties aiProperties;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;
public AiClientFactory(AiProperties aiProperties,
RestTemplate restTemplate,
ObjectMapper objectMapper) {
this.aiProperties = aiProperties;
this.restTemplate = restTemplate;
this.objectMapper = objectMapper;
}
/**
* Create an AI client for the given provider
*/
public AiClient createAiClient(String provider) {
return clientCache.computeIfAbsent(provider, this::createNewAiClient);
}
/**
* Select a client automatically based on the model name
*/
public AiClient createAiClientForModel(String model) {
String provider = determineProviderByModel(model);
return createAiClient(provider);
}
private AiClient createNewAiClient(String provider) {
log.info("Creating new AI client for provider: {}", provider);
switch (provider.toLowerCase()) {
case "openai":
return createOpenAiClient();
case "azure":
return createAzureAiClient();
case "local":
return createLocalAiClient();
default:
throw new IllegalArgumentException("Unsupported AI provider: " + provider);
}
}
private AiClient createOpenAiClient() {
var properties = aiProperties.getOpenai();
var strategy = new OpenAiStrategy(
restTemplate,
objectMapper,
properties.getApiKey(),
properties.getBaseUrl()
);
return new OpenAiClient(strategy, aiProperties);
}
private AiClient createAzureAiClient() {
var properties = aiProperties.getAzure();
var strategy = new AzureAiStrategy(
restTemplate,
objectMapper,
properties.getApiKey(),
properties.getEndpoint(),
properties.getDeploymentName(),
properties.getApiVersion()
);
return new AzureAiClient(strategy, aiProperties);
}
private AiClient createLocalAiClient() {
var properties = aiProperties.getLocal();
var strategy = new LocalAiStrategy(
properties.getModelPath(),
properties.getDevice()
);
return new LocalAiClient(strategy, aiProperties);
}
private String determineProviderByModel(String model) {
if (model == null) {
return aiProperties.getDefaultProvider();
}
if (model.startsWith("gpt-") || model.startsWith("text-")) {
return "openai";
} else if (model.contains("azure")) {
return "azure";
} else if (model.contains("local") || model.contains("llama")) {
return "local";
}
return aiProperties.getDefaultProvider();
}
/**
* Get all supported provider strategies
*/
public Map<String, AiProviderStrategy> getSupportedStrategies() {
Map<String, AiProviderStrategy> strategies = new HashMap<>();
if (aiProperties.getOpenai().isEnabled()) {
strategies.put("openai", new OpenAiStrategy(
restTemplate, objectMapper,
aiProperties.getOpenai().getApiKey(),
aiProperties.getOpenai().getBaseUrl()
));
}
if (aiProperties.getAzure().isEnabled()) {
strategies.put("azure", new AzureAiStrategy(
restTemplate, objectMapper,
aiProperties.getAzure().getApiKey(),
aiProperties.getAzure().getEndpoint(),
aiProperties.getAzure().getDeploymentName(),
aiProperties.getAzure().getApiVersion()
));
}
if (aiProperties.getLocal().isEnabled()) {
strategies.put("local", new LocalAiStrategy(
aiProperties.getLocal().getModelPath(),
aiProperties.getLocal().getDevice()
));
}
return strategies;
}
}
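The factory's `computeIfAbsent` cache guarantees at most one client per provider key, no matter how many callers race on `createAiClient`. This small sketch demonstrates that caching contract in isolation, with a plain `Object` standing in for the real clients (names are illustrative):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the factory's caching behaviour: one client instance per provider
// key, regardless of how many times createAiClient is called.
public class FactoryCacheDemo {
    final Map<String, Object> cache = new ConcurrentHashMap<>();
    final AtomicInteger creations = new AtomicInteger();

    Object createAiClient(String provider) {
        return cache.computeIfAbsent(provider, p -> {
            creations.incrementAndGet();      // count real constructions
            return new Object();              // stands in for OpenAiClient etc.
        });
    }
}
```

With `ConcurrentHashMap`, `computeIfAbsent` also guarantees the creation function runs at most once per key even under concurrent access, which is why the cache field type matters.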
4. Adapter Pattern - Model Adapters
// Adapter interface
package com.example.ai.pattern.adapter;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
/**
* AI model adapter interface
* Converts unified requests into a model-specific format
*/
public interface AiModelAdapter {
String getModelType();
boolean supports(String model);
Object adaptRequest(ChatRequest request);
ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest);
}
// OpenAI adapter implementation
package com.example.ai.pattern.adapter;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
@RequiredArgsConstructor
public class OpenAiAdapter implements AiModelAdapter {
private final ObjectMapper objectMapper;
@Override
public String getModelType() {
return "openai";
}
@Override
public boolean supports(String model) {
return model != null &&
(model.startsWith("gpt-") ||
model.startsWith("text-"));
}
@Override
public Object adaptRequest(ChatRequest request) {
Map<String, Object> openAiRequest = new HashMap<>();
// Map the model name
openAiRequest.put("model", request.getModel());
// Convert the messages
List<Map<String, String>> messages = request.getMessages().stream()
.map(this::convertMessage)
.collect(Collectors.toList());
openAiRequest.put("messages", messages);
// Copy optional parameters, skipping unset values
if (request.getTemperature() != null) {
openAiRequest.put("temperature", request.getTemperature());
}
if (request.getMaxTokens() != null) {
openAiRequest.put("max_tokens", request.getMaxTokens());
}
if (request.getTopP() != null) {
openAiRequest.put("top_p", request.getTopP());
}
// Streaming flag
if (request.isStream()) {
openAiRequest.put("stream", true);
}
log.debug("Adapted OpenAI request: {}", openAiRequest);
return openAiRequest;
}
@Override
public ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest) {
try {
String responseJson = objectMapper.writeValueAsString(rawResponse);
Map<String, Object> responseMap = objectMapper.readValue(responseJson, Map.class);
return ChatResponse.builder()
.id((String) responseMap.get("id"))
.model((String) responseMap.get("model"))
.created((Integer) responseMap.get("created"))
.choices(extractChoices(responseMap))
.usage(extractUsage(responseMap))
.build();
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to adapt OpenAI response", e);
}
}
private Map<String, String> convertMessage(Message message) {
Map<String, String> result = new HashMap<>();
result.put("role", message.getRole().name().toLowerCase());
result.put("content", message.getContent());
return result;
}
private List<ChatResponse.Choice> extractChoices(Map<String, Object> responseMap) {
    // Choice extraction omitted for brevity
    return null;
}
private ChatResponse.Usage extractUsage(Map<String, Object> responseMap) {
    // Usage extraction omitted for brevity
    return null;
}
}
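The adapter's core job in `adaptRequest` is mechanical: translate the framework's camelCase fields into the provider's snake_case wire keys, and omit optionals that were never set. That translation can be exercised without any HTTP machinery (the record and method names below are illustrative stand-ins for `ChatRequest` and the adapter):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch of the adapter's request translation: camelCase fields become
// snake_case wire keys, and unset optionals are omitted entirely.
public class AdapterDemo {
    record UnifiedRequest(String model, Double temperature, Integer maxTokens) {}

    static Map<String, Object> adapt(UnifiedRequest r) {
        Map<String, Object> out = new LinkedHashMap<>();
        out.put("model", r.model());
        if (r.temperature() != null) out.put("temperature", r.temperature());
        if (r.maxTokens() != null) out.put("max_tokens", r.maxTokens());
        return out;
    }
}
```

Omitting unset keys matters: sending `"max_tokens": null` would be rejected by most provider APIs, whereas a missing key falls back to the provider's default.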
5. Chain of Responsibility - Interceptor Chain
// Interceptor interface
package com.example.ai.pattern.chain;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
/**
* AI client interceptor
*/
public interface AiClientInterceptor {
/**
* Pre-processing; return false to veto the request
*/
default boolean preHandle(ChatRequest request) {
return true;
}
/**
* Post-processing
*/
default void postHandle(ChatRequest request, ChatResponse response) {
}
/**
* Completion callback (invoked even when the request fails)
*/
default void afterCompletion(ChatRequest request, Exception ex) {
}
}
// Interceptor chain
package com.example.ai.pattern.chain;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
@RequiredArgsConstructor
public class AiClientInterceptorChain {
private final List<AiClientInterceptor> interceptors;
/**
* Apply the pre-handlers of all interceptors
*/
public boolean applyPreHandle(ChatRequest request) {
for (AiClientInterceptor interceptor : interceptors) {
if (!interceptor.preHandle(request)) {
log.debug("Interceptor {} prevented request processing",
interceptor.getClass().getSimpleName());
return false;
}
}
return true;
}
/**
* Apply the post-handlers of all interceptors
*/
public void applyPostHandle(ChatRequest request, ChatResponse response) {
for (AiClientInterceptor interceptor : interceptors) {
try {
interceptor.postHandle(request, response);
} catch (Exception e) {
log.error("Interceptor postHandle failed", e);
}
}
}
/**
* Apply the completion callbacks of all interceptors
*/
public void applyAfterCompletion(ChatRequest request, Exception ex) {
for (AiClientInterceptor interceptor : interceptors) {
try {
interceptor.afterCompletion(request, ex);
} catch (Exception e) {
log.error("Interceptor afterCompletion failed", e);
}
}
}
}
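The chain's defining behaviour is the veto in `applyPreHandle`: interceptors run in registration order, and the first one that returns `false` short-circuits the rest. A minimal runnable sketch of exactly that semantics (interface and class names are illustrative):

```java
import java.util.List;

// Sketch of the chain's veto semantics: preHandle runs in order and the first
// interceptor returning false stops the request.
public class ChainDemo {
    interface Interceptor {
        boolean preHandle(String request);
    }

    static boolean applyPreHandle(List<Interceptor> chain, String request) {
        for (Interceptor i : chain) {
            if (!i.preHandle(request)) {
                return false;  // later interceptors never run
            }
        }
        return true;
    }
}
```

Note that `applyPostHandle` and `applyAfterCompletion` in the real chain deliberately swallow per-interceptor exceptions, so one misbehaving interceptor cannot suppress the others' post-processing.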
6. Builder Pattern - Request Builder
// Chat request builder
package com.example.ai.pattern.builder;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.Message;
import java.util.ArrayList;
import java.util.List;
/**
* Chat request builder
* Uses the builder pattern to assemble complex request objects fluently
*/
public class ChatRequestBuilder {
private String model;
private final List<Message> messages = new ArrayList<>();
private Double temperature;
private Integer maxTokens;
private Double topP;
private boolean stream = false;
public ChatRequestBuilder model(String model) {
this.model = model;
return this;
}
public ChatRequestBuilder message(Message message) {
this.messages.add(message);
return this;
}
public ChatRequestBuilder message(String role, String content) {
return message(Message.builder()
.role(Message.Role.valueOf(role.toUpperCase()))
.content(content)
.build());
}
public ChatRequestBuilder systemMessage(String content) {
return message("system", content);
}
public ChatRequestBuilder userMessage(String content) {
return message("user", content);
}
public ChatRequestBuilder assistantMessage(String content) {
return message("assistant", content);
}
public ChatRequestBuilder temperature(Double temperature) {
this.temperature = temperature;
return this;
}
public ChatRequestBuilder maxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
return this;
}
public ChatRequestBuilder topP(Double topP) {
this.topP = topP;
return this;
}
public ChatRequestBuilder stream(boolean stream) {
this.stream = stream;
return this;
}
public ChatRequest build() {
if (model == null || model.trim().isEmpty()) {
throw new IllegalArgumentException("Model is required");
}
if (messages.isEmpty()) {
throw new IllegalArgumentException("At least one message is required");
}
return ChatRequest.builder()
.model(model)
.messages(new ArrayList<>(messages))
.temperature(temperature)
.maxTokens(maxTokens)
.topP(topP)
.stream(stream)
.build();
}
}
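The builder's value is the fluent accumulation of messages plus the fail-fast validation in `build()`. A self-contained sketch of the same API surface, using simple records instead of the Lombok-based `ChatRequest`/`Message` (all names here are illustrative):

```java
import java.util.ArrayList;
import java.util.List;

// Self-contained sketch of the fluent builder with fail-fast validation.
public class BuilderDemo {
    record Msg(String role, String content) {}
    record Request(String model, List<Msg> messages, Double temperature) {}

    static class Builder {
        private String model;
        private final List<Msg> messages = new ArrayList<>();
        private Double temperature;

        Builder model(String m) { this.model = m; return this; }
        Builder systemMessage(String c) { messages.add(new Msg("system", c)); return this; }
        Builder userMessage(String c) { messages.add(new Msg("user", c)); return this; }
        Builder temperature(Double t) { this.temperature = t; return this; }

        Request build() {
            // Validate before constructing, so invalid requests never exist.
            if (model == null || model.isBlank()) {
                throw new IllegalArgumentException("Model is required");
            }
            if (messages.isEmpty()) {
                throw new IllegalArgumentException("At least one message is required");
            }
            return new Request(model, List.copyOf(messages), temperature);
        }
    }
}
```

`List.copyOf` in `build()` mirrors the defensive `new ArrayList<>(messages)` in the article's builder: the built request cannot be mutated through the builder afterwards.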
🔄 VI. Model Router
package com.example.ai.router;
import com.example.ai.core.AiClient;
import com.example.ai.pattern.factory.AiClientFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Model router
* Routes each model name to a suitable AI client
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class ModelRouter {
private final AiClientFactory aiClientFactory;
private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
/**
* Route to a suitable AI client
*/
public AiClient route(String model) {
return clientCache.computeIfAbsent(model, this::createClientForModel);
}
/**
* List all available models
*/
public Map<String, String> getAvailableModels() {
Map<String, String> models = new ConcurrentHashMap<>();
// OpenAI models
models.put("gpt-3.5-turbo", "openai");
models.put("gpt-4", "openai");
models.put("text-embedding-ada-002", "openai");
// Azure models
models.put("gpt-35-turbo", "azure");
models.put("gpt-4-azure", "azure");
// Local models
models.put("llama-2-7b", "local");
models.put("vicuna-13b", "local");
return models;
}
/**
* Select the best provider for a model
*/
public String selectBestProvider(String model, String preferredProvider) {
if (preferredProvider != null) {
return preferredProvider;
}
var availableModels = getAvailableModels();
return availableModels.getOrDefault(model, "openai");
}
private AiClient createClientForModel(String model) {
String provider = determineProvider(model);
log.info("Creating AI client for model: {} -> provider: {}", model, provider);
return aiClientFactory.createAiClient(provider);
}
private String determineProvider(String model) {
    if (model == null) {
        return "openai";
    }
    String normalized = model.toLowerCase();
    if (normalized.startsWith("gpt-") || normalized.startsWith("text-")) {
        return "openai";
    } else if (normalized.contains("azure")) {
        return "azure";
    } else if (normalized.contains("local") || normalized.contains("llama")) {
        return "local";
    }
    return "openai";
}
}
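The routing rules reduce to a pure function over the model name, which makes them easy to test in isolation. This sketch mirrors the prefix/substring logic of `determineProvider` (the class name is illustrative); note that because the `gpt-` prefix is checked first, an Azure deployment named `gpt-35-turbo` would still route to "openai" here, which is why `selectBestProvider` accepts an explicit preferred provider as an override:

```java
// Pure-function sketch of the router's model-name -> provider mapping.
public class RouterDemo {
    static String determineProvider(String model) {
        if (model == null) return "openai";           // fall back to the default
        String m = model.toLowerCase();
        if (m.startsWith("gpt-") || m.startsWith("text-")) return "openai";
        if (m.contains("azure")) return "azure";
        if (m.contains("local") || m.contains("llama")) return "local";
        return "openai";
    }
}
```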
📊 VII. Metrics Collection
package com.example.ai.metrics;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
* AI client metrics
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class AiClientMetrics {
private final MeterRegistry meterRegistry;
private final Map<String, Timer> timers = new ConcurrentHashMap<>();
/**
* Record a successful request
*/
public void recordSuccess(String provider, String model, String operation, long duration) {
getTimer(provider, model, operation)
.record(duration, TimeUnit.MILLISECONDS);
incrementCounter("ai.client.requests.success",
Map.of("provider", provider, "model", model, "operation", operation));
// Record the response-time distribution
meterRegistry.timer("ai.client.response.time",
"provider", provider,
"model", model,
"operation", operation)
.record(duration, TimeUnit.MILLISECONDS);
}
/**
* Record a failed request
*/
public void recordError(String provider, String model, String operation, String errorType) {
incrementCounter("ai.client.requests.error",
Map.of("provider", provider,
"model", model,
"operation", operation,
"error_type", errorType));
}
/**
* Record token usage
*/
public void recordTokenUsage(String provider, String model,
int promptTokens, int completionTokens, int totalTokens) {
meterRegistry.counter("ai.client.tokens.prompt",
"provider", provider, "model", model)
.increment(promptTokens);
meterRegistry.counter("ai.client.tokens.completion",
"provider", provider, "model", model)
.increment(completionTokens);
meterRegistry.counter("ai.client.tokens.total",
"provider", provider, "model", model)
.increment(totalTokens);
}
/**
* Record a streaming response chunk
*/
public void recordStreamChunk(String provider, String model) {
incrementCounter("ai.client.stream.chunks",
Map.of("provider", provider, "model", model));
}
private Timer getTimer(String provider, String model, String operation) {
String key = provider + ":" + model + ":" + operation;
return timers.computeIfAbsent(key, k ->
Timer.builder("ai.client.requests")
.tag("provider", provider)
.tag("model", model)
.tag("operation", operation)
.publishPercentiles(0.5, 0.95, 0.99)
.serviceLevelObjectives(java.time.Duration.ofMillis(100),
        java.time.Duration.ofMillis(500),
        java.time.Duration.ofMillis(1000))
.register(meterRegistry)
);
}
private void incrementCounter(String name, Map<String, String> tags) {
Counter.Builder builder = Counter.builder(name);
tags.forEach(builder::tag);
builder.register(meterRegistry).increment();
}
}
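The `getTimer` cache above keys meters on a `provider:model:operation` string so that the `Timer.Builder` is not rebuilt on every request (Micrometer itself already deduplicates meters by name and tags when you re-register them). The same keyed-caching idea can be shown without the Micrometer dependency, using plain counters (class and method names are illustrative):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

// Dependency-free sketch of caching a meter per provider:model:operation key.
public class MetricsCacheDemo {
    private final Map<String, LongAdder> counters = new ConcurrentHashMap<>();

    void increment(String provider, String model, String operation) {
        String key = provider + ":" + model + ":" + operation;
        // One LongAdder per distinct tag combination, created lazily.
        counters.computeIfAbsent(key, k -> new LongAdder()).increment();
    }

    long count(String provider, String model, String operation) {
        LongAdder a = counters.get(provider + ":" + model + ":" + operation);
        return a == null ? 0 : a.sum();
    }
}
```

`LongAdder` is preferred over `AtomicLong` here because metric increments are write-heavy and read rarely, which is exactly the workload `LongAdder` is optimized for.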
🎯 VIII. Usage Examples
1. Chat Service Example
package com.example.ai.service;
import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.example.ai.router.ModelRouter;
import com.example.ai.metrics.AiClientMetrics;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatService {
private final ModelRouter modelRouter;
private final AiClientMetrics metrics;
/**
* Synchronous chat
*/
public String chat(String model, String userMessage) {
    long startTime = System.currentTimeMillis();
    // Resolve the provider once so metrics are tagged correctly for any model
    String provider = modelRouter.selectBestProvider(model, null);
    try {
        // Pick the right AI client for the model
        AiClient aiClient = modelRouter.route(model);
        // Build the request
        ChatRequest request = ChatRequest.builder()
                .model(model)
                .messages(List.of(
                        Message.userMessage(userMessage)
                ))
                .temperature(0.7)
                .maxTokens(1000)
                .build();
        // Execute the chat call
        ChatResponse response = aiClient.chat(request);
        String content = response.getContent();
        // Record metrics
        long duration = System.currentTimeMillis() - startTime;
        metrics.recordSuccess(provider, model, "chat", duration);
        metrics.recordTokenUsage(
                provider, model,
                response.getUsage().getPromptTokens(),
                response.getUsage().getCompletionTokens(),
                response.getUsage().getTotalTokens()
        );
        return content;
    } catch (Exception e) {
        metrics.recordError(provider, model, "chat", e.getClass().getSimpleName());
        throw e;
    }
}
/**
* Asynchronous chat
*/
public CompletableFuture<String> chatAsync(String model, String userMessage) {
return CompletableFuture.supplyAsync(() -> chat(model, userMessage));
}
/**
* Chat with conversation context
*/
public String chatWithContext(String model, String systemPrompt,
List<Message> history, String userMessage) {
AiClient aiClient = modelRouter.route(model);
// Assemble the message list
List<Message> messages = new java.util.ArrayList<>();
messages.add(Message.systemMessage(systemPrompt));
messages.addAll(history);
messages.add(Message.userMessage(userMessage));
ChatRequest request = ChatRequest.builder()
.model(model)
.messages(messages)
.temperature(0.7)
.maxTokens(2000)
.build();
return aiClient.chat(request).getContent();
}
}
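`chatWithContext` assembles messages in a fixed order that chat models expect: the system prompt first, then the prior turns, then the new user message. This small sketch isolates that ordering (the record and method names are illustrative stand-ins for `Message` and the service method):

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of context assembly: system prompt, then history, then the new turn.
public class ContextAssemblyDemo {
    record Msg(String role, String content) {}

    static List<Msg> assemble(String systemPrompt, List<Msg> history, String userMessage) {
        List<Msg> messages = new ArrayList<>();
        messages.add(new Msg("system", systemPrompt));  // instructions come first
        messages.addAll(history);                        // prior turns, oldest first
        messages.add(new Msg("user", userMessage));      // the new question last
        return messages;
    }
}
```

Getting this order wrong (e.g. appending the system prompt last) typically degrades model behaviour silently, which is why the service centralizes the assembly instead of leaving it to callers.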
2. REST Controller
package com.example.ai.controller;
import com.example.ai.service.ChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
@Slf4j
@RestController
@RequestMapping("/v1/ai") // the servlet context-path /api is prepended, giving /api/v1/ai
@RequiredArgsConstructor
public class AiController {
private final ChatService chatService;
/**
* Simple chat endpoint
*/
@PostMapping("/chat")
public ResponseEntity<Map<String, Object>> chat(
@RequestParam(defaultValue = "gpt-3.5-turbo") String model,
@RequestBody Map<String, String> request) {
String message = request.get("message");
if (message == null || message.trim().isEmpty()) {
return ResponseEntity.badRequest()
.body(Map.of("error", "Message is required"));
}
try {
String response = chatService.chat(model, message);
return ResponseEntity.ok(Map.of(
"model", model,
"response", response
));
} catch (Exception e) {
log.error("Chat failed", e);
return ResponseEntity.internalServerError()
.body(Map.of("error", e.getMessage()));
}
}
/**
* Asynchronous chat endpoint
*/
@PostMapping("/chat/async")
public CompletableFuture<ResponseEntity<Map<String, Object>>> chatAsync(
@RequestParam(defaultValue = "gpt-3.5-turbo") String model,
@RequestBody Map<String, String> request) {
return chatService.chatAsync(model, request.get("message"))
.thenApply(response -> ResponseEntity.ok(Map.of(
"model", model,
"response", response
)))
.exceptionally(e -> ResponseEntity.internalServerError()
.body(Map.of("error", e.getMessage())));
}
/**
* Health check
*/
@GetMapping("/health")
public ResponseEntity<Map<String, Object>> health() {
return ResponseEntity.ok(Map.of(
"status", "UP",
"timestamp", System.currentTimeMillis()
));
}
/**
* 获取支持的模型
*/
@GetMapping("/models")
public ResponseEntity<Map<String, Object>> getAvailableModels() {
return ResponseEntity.ok(Map.of(
"models", Map.of(
"gpt-3.5-turbo", "OpenAI GPT-3.5 Turbo",
"gpt-4", "OpenAI GPT-4",
"text-embedding-ada-002", "OpenAI Embedding Model"
)
));
}
}
🧪 IX. Test Code
package com.example.ai.service;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.router.ModelRouter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class ChatServiceTest {

    @Mock
    private ModelRouter modelRouter;

    @Mock
    private AiClient aiClient;

    private ChatService chatService;

    @BeforeEach
    void setUp() {
        chatService = new ChatService(modelRouter, null);
    }

    @Test
    void testChat_Success() {
        // Prepare test data
        String model = "gpt-3.5-turbo";
        String userMessage = "Hello, how are you?";
        String expectedResponse = "I'm fine, thank you!";

        // Stub the collaborators
        when(modelRouter.route(model)).thenReturn(aiClient);
        ChatResponse mockResponse = ChatResponse.builder()
                .content(expectedResponse)
                .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);

        // Execute
        String result = chatService.chat(model, userMessage);

        // Verify
        assertEquals(expectedResponse, result);
        verify(modelRouter, times(1)).route(model);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }

    @Test
    void testChat_WithSystemPrompt() {
        // Prepare test data
        String model = "gpt-3.5-turbo";
        String systemPrompt = "You are a helpful assistant.";
        String userMessage = "What is AI?";

        // Stub the collaborators
        when(modelRouter.route(model)).thenReturn(aiClient);
        ChatResponse mockResponse = ChatResponse.builder()
                .content("AI stands for Artificial Intelligence.")
                .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);

        // Execute
        String result = chatService.chatWithContext(
                model, systemPrompt, List.of(), userMessage
        );

        // Verify
        assertNotNull(result);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }

    @Test
    void testChat_ModelNotFound() {
        // Prepare test data
        String model = "unknown-model";
        String userMessage = "Hello";

        // Stub the router to fail for an unknown model
        when(modelRouter.route(model)).thenThrow(
                new IllegalArgumentException("Model not found: " + model)
        );

        // Execute and verify the exception
        assertThrows(IllegalArgumentException.class, () -> {
            chatService.chat(model, userMessage);
        });
    }
}
📈 X. Architecture of the Combined Design Patterns
🔄 XI. Workflow Sequence Diagram
🎯 XII. Running and Testing
1. Start the Application
# Set environment variables
export OPENAI_API_KEY=your-api-key-here
# Build and run
mvn clean package
java -jar target/spring-ai-simulation-1.0.0.jar
# Or run directly with Maven
mvn spring-boot:run
2. Test the API
# Health check
curl http://localhost:8080/api/v1/ai/health
# List the available models
curl http://localhost:8080/api/v1/ai/models
# Chat test
curl -X POST http://localhost:8080/api/v1/ai/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "Hello, how are you?"}'
# Chat with a specific model
curl -X POST "http://localhost:8080/api/v1/ai/chat?model=gpt-3.5-turbo" \
  -H "Content-Type: application/json" \
  -d '{"message": "What is Spring AI?"}'
3. Monitoring Metrics
- Prometheus metrics: http://localhost:8080/actuator/prometheus
- Health check: http://localhost:8080/actuator/health
- Application info: http://localhost:8080/actuator/info
- Metric details: http://localhost:8080/actuator/metrics
📊 XIII. Design Pattern Summary
| Design Pattern | Where Applied | Problem Solved | Key Implementation Class |
|---|---|---|---|
| Factory | AI client creation | Uniform creation of clients for different AI providers | AiClientFactory |
| Strategy | AI provider switching | Support for multiple AI service providers | AiProviderStrategy |
| Template Method | Request processing flow | A standardized AI request processing flow | AbstractAiClient |
| Adapter | Model format conversion | Converting request/response formats between models | AiModelAdapter |
| Chain of Responsibility | Interceptor handling | A pluggable interceptor chain | AiClientInterceptorChain |
| Builder | Request object construction | Building complex request objects | ChatRequestBuilder |
| Singleton | Configuration management | A single shared configuration object | AiProperties |
| Proxy | Monitoring enhancement | Adding monitoring to AI clients | AiClientProxy |
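The table credits the proxy pattern to AiClientProxy, whose source is not shown in this post. As a rough sketch of the idea, a JDK dynamic proxy can wrap a client, time every call, and hand the measurement to a metrics backend. The `AiClient` interface below is a simplified stand-in and the class names are illustrative assumptions, not the framework's actual code:

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class MonitoringProxyDemo {

    // Simplified stand-in for the framework's AiClient interface
    interface AiClient {
        String chat(String message);
    }

    // InvocationHandler that times every call on the wrapped client
    static class MonitoringHandler implements InvocationHandler {
        private final AiClient target;

        MonitoringHandler(AiClient target) {
            this.target = target;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            long start = System.nanoTime();
            try {
                return method.invoke(target, args);
            } finally {
                long elapsedMs = (System.nanoTime() - start) / 1_000_000;
                // A real implementation would record this with a Micrometer Timer
                System.out.println(method.getName() + " took " + elapsedMs + " ms");
            }
        }
    }

    // Wrap any AiClient in a monitoring proxy without changing its code
    static AiClient wrap(AiClient target) {
        return (AiClient) Proxy.newProxyInstance(
                AiClient.class.getClassLoader(),
                new Class<?>[]{AiClient.class},
                new MonitoringHandler(target));
    }

    public static void main(String[] args) {
        AiClient real = message -> "echo: " + message;
        AiClient monitored = wrap(real);
        System.out.println(monitored.chat("hello"));
    }
}
```

In the real framework the `System.out.println` line would be replaced by a call such as Micrometer's `Timer.record(elapsedMs, TimeUnit.MILLISECONDS)`, which is how the latency would surface under /actuator/prometheus.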
🔧 XIV. Extension Guide
1. Add a New AI Provider
@Component
public class CustomAiStrategy implements AiProviderStrategy {

    @Override
    public String getProviderName() {
        return "custom";
    }

    @Override
    public ChatResponse chat(ChatRequest request) {
        // Call your custom AI service here and map its response to ChatResponse
        throw new UnsupportedOperationException("Custom AI call not implemented yet");
    }

    // Implement the remaining interface methods...
}
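For the factory or router to pick up a new provider like this, a common approach is to collect all strategy beans into a registry keyed by `getProviderName()` (Spring injects them as a `List`). The sketch below shows that lookup with plain JDK types; the `StrategyRegistry` class is an illustrative assumption, not part of the framework code shown above:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class StrategyRegistryDemo {

    // Simplified stand-in for AiProviderStrategy
    interface AiProviderStrategy {
        String getProviderName();
        String chat(String message);
    }

    // Registry keyed by provider name; in Spring, the List would be injected
    static class StrategyRegistry {
        private final Map<String, AiProviderStrategy> byName;

        StrategyRegistry(List<AiProviderStrategy> strategies) {
            this.byName = strategies.stream()
                    .collect(Collectors.toMap(
                            AiProviderStrategy::getProviderName,
                            Function.identity()));
        }

        AiProviderStrategy get(String provider) {
            AiProviderStrategy strategy = byName.get(provider);
            if (strategy == null) {
                throw new IllegalArgumentException("No strategy for provider: " + provider);
            }
            return strategy;
        }
    }

    public static void main(String[] args) {
        AiProviderStrategy custom = new AiProviderStrategy() {
            public String getProviderName() { return "custom"; }
            public String chat(String message) { return "[custom] " + message; }
        };
        StrategyRegistry registry = new StrategyRegistry(List.of(custom));
        System.out.println(registry.get("custom").chat("ping"));
    }
}
```

With this shape, adding a provider is purely additive: declare one new `@Component` strategy and the registry picks it up with no changes to existing code.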
2. Add a Custom Interceptor
@Component
public class CustomInterceptor implements AiClientInterceptor {

    @Override
    public boolean preHandle(ChatRequest request) {
        // Custom pre-processing logic; return false to short-circuit the chain
        return true;
    }

    @Override
    public void postHandle(ChatRequest request, ChatResponse response) {
        // Custom post-processing logic
    }
}
3. Configure a Custom Model
spring:
  ai:
    custom:
      enabled: true
      api-key: ${CUSTOM_API_KEY}
      endpoint: https://api.custom-ai.com/v1
      models:
        - custom-model-1
        - custom-model-2
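A binding class for such a configuration block could look like the following sketch. The class name `CustomAiProperties`, the field names, and the `spring.ai.custom` prefix are assumptions mirroring the YAML above; the framework's actual AiProperties may be structured differently:

```java
package com.example.ai.config;

import java.util.List;
import org.springframework.boot.context.properties.ConfigurationProperties;

// Binds the spring.ai.custom.* block from application.yml
// (Spring Boot's relaxed binding maps api-key to apiKey automatically)
@ConfigurationProperties(prefix = "spring.ai.custom")
public class CustomAiProperties {

    private boolean enabled;
    private String apiKey;
    private String endpoint;
    private List<String> models;

    public boolean isEnabled() { return enabled; }
    public void setEnabled(boolean enabled) { this.enabled = enabled; }

    public String getApiKey() { return apiKey; }
    public void setApiKey(String apiKey) { this.apiKey = apiKey; }

    public String getEndpoint() { return endpoint; }
    public void setEndpoint(String endpoint) { this.endpoint = endpoint; }

    public List<String> getModels() { return models; }
    public void setModels(List<String> models) { this.models = models; }
}
```

Register it with `@EnableConfigurationProperties(CustomAiProperties.class)` or `@ConfigurationPropertiesScan` so Spring binds the YAML block at startup.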
This complete Spring AI simulation framework demonstrates how combining multiple design patterns produces a flexible, extensible, and maintainable AI service integration framework. Each pattern solves a specific problem, and together they form a complete solution.