Spring AI Design Patterns: Combined Application and a Complete Project Implementation

🎯 I. Project Overview

This project simulates the core mechanics of Spring AI, combining several design patterns to build a complete AI service integration framework. We will implement the following core features:

  1. Unified AI client interface - supports multiple AI models (OpenAI, Azure, local models)
  2. Smart routing - automatically selects an adapter based on the model type
  3. Extensible plugin architecture - supports custom model extensions
  4. Full configuration management - Spring Boot auto-configuration
  5. Monitoring and metrics collection - Micrometer integration
  6. A complete sample application - demonstrating how to use the framework

📁 II. Complete Project Structure

spring-ai-simulation/
├── pom.xml                              # Maven 依赖配置
├── src/
│   ├── main/
│   │   ├── java/
│   │   │   └── com/
│   │   │       └── example/
│   │   │           └── ai/
│   │   │               ├── SpringAiSimulationApplication.java
│   │   │               ├── annotation/
│   │   │               │   └── EnableAiClient.java
│   │   │               ├── config/
│   │   │               │   ├── AiAutoConfiguration.java
│   │   │               │   ├── AiProperties.java
│   │   │               │   ├── AiClientConfig.java
│   │   │               │   └── condition/
│   │   │               │       ├── OnAiModelCondition.java
│   │   │               │       └── OnAiProviderCondition.java
│   │   │               ├── core/
│   │   │               │   ├── AiClient.java
│   │   │               │   ├── ChatClient.java
│   │   │               │   ├── EmbeddingClient.java
│   │   │               │   ├── model/
│   │   │               │   │   ├── ChatRequest.java
│   │   │               │   │   ├── ChatResponse.java
│   │   │               │   │   ├── Message.java
│   │   │               │   │   ├── Choice.java
│   │   │               │   │   ├── Usage.java
│   │   │               │   │   ├── EmbeddingRequest.java
│   │   │               │   │   └── EmbeddingResponse.java
│   │   │               │   ├── template/
│   │   │               │   │   ├── AiClientTemplate.java
│   │   │               │   │   ├── BaseAiClient.java
│   │   │               │   │   ├── AbstractAiClient.java
│   │   │               │   │   └── ChatClientTemplate.java
│   │   │               │   └── exception/
│   │   │               │       ├── AiClientException.java
│   │   │               │       ├── AiServiceException.java
│   │   │               │       ├── RateLimitException.java
│   │   │               │       └── UnauthorizedException.java
│   │   │               ├── pattern/
│   │   │               │   ├── factory/
│   │   │               │   │   ├── AiClientFactory.java
│   │   │               │   │   └── AiAdapterFactory.java
│   │   │               │   ├── strategy/
│   │   │               │   │   ├── AiStrategy.java
│   │   │               │   │   ├── AiProviderStrategy.java
│   │   │               │   │   ├── OpenAiStrategy.java
│   │   │               │   │   ├── AzureAiStrategy.java
│   │   │               │   │   └── LocalAiStrategy.java
│   │   │               │   ├── adapter/
│   │   │               │   │   ├── AiModelAdapter.java
│   │   │               │   │   ├── OpenAiAdapter.java
│   │   │               │   │   ├── AzureAiAdapter.java
│   │   │               │   │   └── LocalAiAdapter.java
│   │   │               │   ├── builder/
│   │   │               │   │   └── ChatRequestBuilder.java
│   │   │               │   ├── chain/
│   │   │               │   │   ├── AiClientInterceptor.java
│   │   │               │   │   └── AiClientInterceptorChain.java
│   │   │               │   └── proxy/
│   │   │               │       └── AiClientProxy.java
│   │   │               ├── provider/
│   │   │               │   ├── OpenAiClient.java
│   │   │               │   ├── AzureAiClient.java
│   │   │               │   └── LocalAiClient.java
│   │   │               ├── service/
│   │   │               │   ├── AiService.java
│   │   │               │   ├── ChatService.java
│   │   │               │   └── EmbeddingService.java
│   │   │               ├── interceptor/
│   │   │               │   ├── LoggingInterceptor.java
│   │   │               │   ├── MetricsInterceptor.java
│   │   │               │   └── RetryInterceptor.java
│   │   │               ├── metrics/
│   │   │               │   └── AiClientMetrics.java
│   │   │               └── router/
│   │   │                   └── ModelRouter.java
│   │   └── resources/
│   │       ├── application.yml
│   │       └── application-dev.yml
│   └── test/
│       └── java/
│           └── com/
│               └── example/
│                   └── ai/
│                       ├── SpringAiSimulationApplicationTests.java
│                       └── service/
│                           └── AiServiceTest.java

📦 III. Maven Dependency Configuration

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>spring-ai-simulation</artifactId>
    <version>1.0.0</version>
    <name>Spring AI Simulation</name>
    <description>A simulation framework demonstrating Spring AI design patterns</description>

    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.1.5</version>
        <relativePath/>
    </parent>

    <properties>
        <java.version>17</java.version>
        <jackson.version>2.15.2</jackson.version>
        <micrometer.version>1.11.5</micrometer.version>
    </properties>

    <dependencies>
        <!-- Spring Boot Starters -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-actuator</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        
        <!-- Monitoring -->
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-core</artifactId>
        </dependency>
        <dependency>
            <groupId>io.micrometer</groupId>
            <artifactId>micrometer-registry-prometheus</artifactId>
        </dependency>
        
        <!-- HTTP Client -->
        <dependency>
            <groupId>org.apache.httpcomponents.client5</groupId>
            <artifactId>httpclient5</artifactId>
        </dependency>
        
        <!-- JSON Processing -->
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        
        <!-- Utilities -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>32.1.3-jre</version>
        </dependency>
        
        <!-- Testing -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.mockito</groupId>
            <artifactId>mockito-core</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

⚙️ IV. Core Configuration Files

1. application.yml

# Application configuration
server:
  port: 8080
  servlet:
    context-path: /api

spring:
  application:
    name: spring-ai-simulation
  
  # AI configuration
  ai:
    enabled: true
    default-provider: openai
    default-model: gpt-3.5-turbo
    enable-metrics: true
    
    # OpenAI configuration
    openai:
      enabled: true
      api-key: ${OPENAI_API_KEY:your-openai-key}
      base-url: https://api.openai.com/v1
      timeout: 30s
      max-retries: 3
    
    # Azure AI configuration
    azure:
      enabled: false
      endpoint: https://your-resource.openai.azure.com
      api-key: ${AZURE_API_KEY:your-azure-key}
      deployment-name: gpt-35-turbo
      api-version: 2023-12-01-preview
    
    # Local model configuration
    local:
      enabled: false
      model-path: /path/to/local/model
      device: cpu
    
    # HTTP client configuration
    http-client:
      max-connections: 50
      connection-timeout: 10s
      read-timeout: 30s
      keep-alive: 5m
    
    # Monitoring configuration
    monitoring:
      enabled: true
      metrics-prefix: ai.client
      slow-query-threshold: 1000ms
      enable-tracing: true

# Logging configuration
logging:
  level:
    com.example.ai: DEBUG
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"

# Actuator configuration
management:
  endpoints:
    web:
      exposure:
        include: health,info,metrics,prometheus
  # Spring Boot 3.x property path (management.metrics.export.* is the pre-3.0 form)
  prometheus:
    metrics:
      export:
        enabled: true
  endpoint:
    health:
      show-details: always
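The `AiProperties` class that backs this configuration is listed in the project tree but its source is not shown in this article. Below is a minimal sketch (an assumption about its shape, not the article's actual code) inferred from the YAML above and from calls such as `getOpenai().getApiKey()` and `getDefaultProvider()` made in the factory later. In the real class, `@ConfigurationProperties(prefix = "spring.ai")` would let Spring Boot bind the YAML onto it; the annotation and the remaining nested sections are omitted here to keep the sketch dependency-free:

```java
// Sketch of the configuration-properties POJO implied by application.yml.
// In the real project this class would carry
// @ConfigurationProperties(prefix = "spring.ai") and be registered via
// @EnableConfigurationProperties in AiAutoConfiguration.
public class AiProperties {

    private boolean enabled = true;
    private String defaultProvider = "openai";   // default mirrors application.yml
    private String defaultModel = "gpt-3.5-turbo";
    private final OpenAi openai = new OpenAi();

    public static class OpenAi {
        private boolean enabled = true;
        private String apiKey;
        private String baseUrl = "https://api.openai.com/v1";

        public boolean isEnabled() { return enabled; }
        public void setEnabled(boolean enabled) { this.enabled = enabled; }
        public String getApiKey() { return apiKey; }
        public void setApiKey(String apiKey) { this.apiKey = apiKey; }
        public String getBaseUrl() { return baseUrl; }
        public void setBaseUrl(String baseUrl) { this.baseUrl = baseUrl; }
    }

    public boolean isEnabled() { return enabled; }
    public void setEnabled(boolean enabled) { this.enabled = enabled; }
    public String getDefaultProvider() { return defaultProvider; }
    public void setDefaultProvider(String defaultProvider) { this.defaultProvider = defaultProvider; }
    public String getDefaultModel() { return defaultModel; }
    public void setDefaultModel(String defaultModel) { this.defaultModel = defaultModel; }
    public OpenAi getOpenai() { return openai; }
}
```

The `azure`, `local`, `http-client`, and `monitoring` sections would follow the same nested-class pattern.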

2. Application Entry Point

package com.example.ai;

import com.example.ai.annotation.EnableAiClient;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
@EnableAiClient
public class SpringAiSimulationApplication {
    public static void main(String[] args) {
        SpringApplication.run(SpringAiSimulationApplication.class, args);
    }
}

🔧 V. Core Design Pattern Implementations

1. Strategy Pattern - AI Provider Strategies

// Strategy interface
package com.example.ai.pattern.strategy;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;

/**
 * AI provider strategy interface
 */
public interface AiProviderStrategy {
    String getProviderName();
    
    ChatResponse chat(ChatRequest request);
    
    EmbeddingResponse embed(EmbeddingRequest request);
    
    boolean supports(String model);
    
    void validateApiKey(String apiKey);
}

// OpenAI strategy implementation
package com.example.ai.pattern.strategy;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.client.RestTemplate;

@Slf4j
public class OpenAiStrategy implements AiProviderStrategy {
    
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper;
    private final String apiKey;
    private final String baseUrl;
    
    public OpenAiStrategy(RestTemplate restTemplate, 
                         ObjectMapper objectMapper,
                         String apiKey, 
                         String baseUrl) {
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
        this.apiKey = apiKey;
        this.baseUrl = baseUrl;
    }
    
    @Override
    public String getProviderName() {
        return "openai";
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        log.debug("OpenAI Strategy: Processing chat request for model: {}", request.getModel());
        
        // Build the request in OpenAI-specific format
        var openAiRequest = buildOpenAiChatRequest(request);
        
        // Send the request
        var headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.setBearerAuth(apiKey);
        
        var entity = new HttpEntity<>(openAiRequest, headers);
        var url = String.format("%s/chat/completions", baseUrl);
        
        try {
            var response = restTemplate.postForObject(url, entity, String.class);
            return parseOpenAiChatResponse(response);
        } catch (Exception e) {
            throw new RuntimeException("OpenAI API call failed", e);
        }
    }
    
    @Override
    public EmbeddingResponse embed(EmbeddingRequest request) {
        // Embedding support is omitted in this sample; fail fast rather than return null
        throw new UnsupportedOperationException("Embeddings are not implemented in this example");
    }
    
    @Override
    public boolean supports(String model) {
        return model != null && 
               (model.startsWith("gpt-") || 
                model.startsWith("text-embedding-"));
    }
    
    @Override
    public void validateApiKey(String apiKey) {
        if (apiKey == null || apiKey.trim().isEmpty()) {
            throw new IllegalArgumentException("OpenAI API key is required");
        }
        if (!apiKey.startsWith("sk-")) {
            log.warn("OpenAI API key may be invalid, should start with 'sk-'");
        }
    }
    
    private Object buildOpenAiChatRequest(ChatRequest request) {
        // Request conversion logic (omitted for brevity)
        return null;
    }
    
    private ChatResponse parseOpenAiChatResponse(String response) {
        // Response parsing logic (omitted for brevity)
        return null;
    }
}
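The dispatch side of the pattern - scanning the registered strategies and picking the first one whose `supports()` accepts the requested model - is only implied above. A minimal, self-contained sketch of that selection logic (simplified stand-ins, not the article's actual classes):

```java
import java.util.List;

// Strategy-pattern dispatch in miniature: each strategy declares which models
// it can serve, and the caller picks the first match without knowing any
// provider-specific details.
public class StrategySelection {

    interface ProviderStrategy {
        String providerName();
        boolean supports(String model);
    }

    static class OpenAiLike implements ProviderStrategy {
        public String providerName() { return "openai"; }
        public boolean supports(String model) {
            return model != null
                && (model.startsWith("gpt-") || model.startsWith("text-embedding-"));
        }
    }

    static class LocalLike implements ProviderStrategy {
        public String providerName() { return "local"; }
        public boolean supports(String model) {
            return model != null && model.contains("llama");
        }
    }

    static final List<ProviderStrategy> STRATEGIES =
            List.of(new OpenAiLike(), new LocalLike());

    // Returns the name of the first strategy that supports the model
    public static String route(String model) {
        return STRATEGIES.stream()
                .filter(s -> s.supports(model))
                .map(ProviderStrategy::providerName)
                .findFirst()
                .orElse("unsupported");
    }

    public static void main(String[] args) {
        System.out.println(route("gpt-4"));      // openai
        System.out.println(route("llama-2-7b")); // local
    }
}
```

Adding a new provider means adding one class to the list; the dispatch code never changes, which is the point of the pattern.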

2. Template Method Pattern - AI Client Template

// Abstract template class
package com.example.ai.core.template;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.EmbeddingRequest;
import com.example.ai.core.model.EmbeddingResponse;
import com.example.ai.pattern.chain.AiClientInterceptorChain;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;

/**
 * Abstract AI client template.
 * Defines the standard processing pipeline for AI requests.
 */
@Slf4j
public abstract class AbstractAiClient implements AiClient {
    
    protected final AiClientInterceptorChain interceptorChain;
    protected final Executor executor;
    
    protected AbstractAiClient(AiClientInterceptorChain interceptorChain, 
                              Executor executor) {
        this.interceptorChain = interceptorChain;
        this.executor = executor;
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        // Template method: defines the standard processing pipeline
        long startTime = System.currentTimeMillis();
        
        try {
            // 1. Pre-processing
            preProcess(request);
            
            // 2. Interceptor chain pre-handle; any interceptor may veto the request
            if (!interceptorChain.applyPreHandle(request)) {
                throw new IllegalStateException("Request rejected by interceptor chain");
            }
            
            // 3. Execute the actual call (implemented by the subclass)
            ChatResponse response = doChat(request);
            
            // 4. Interceptor chain post-handle
            interceptorChain.applyPostHandle(request, response);
            
            // 5. Post-processing
            postProcess(request, response, startTime);
            
            return response;
            
        } catch (Exception e) {
            // 6. Exception handling
            handleException(request, e, startTime);
            throw e;
        }
    }
    
    @Override
    public CompletableFuture<ChatResponse> chatAsync(ChatRequest request) {
        return CompletableFuture.supplyAsync(() -> chat(request), executor);
    }
    
    /**
     * Pre-processing (hook method)
     */
    protected void preProcess(ChatRequest request) {
        log.debug("Pre-processing chat request: {}", request);
        validateRequest(request);
    }
    
    /**
     * Actual chat logic (implemented by the subclass)
     */
    protected abstract ChatResponse doChat(ChatRequest request);
    
    /**
     * Post-processing (hook method)
     */
    protected void postProcess(ChatRequest request, ChatResponse response, long startTime) {
        long duration = System.currentTimeMillis() - startTime;
        log.debug("Chat completed in {} ms. Tokens used: {}", 
                 duration, response.getUsage().getTotalTokens());
        
        // Publish a completion event
        publishChatCompletedEvent(request, response, duration);
    }
    
    /**
     * Exception handling
     */
    protected void handleException(ChatRequest request, Exception e, long startTime) {
        long duration = System.currentTimeMillis() - startTime;
        log.error("Chat request failed after {} ms: {}", duration, e.getMessage(), e);
        
        // Run the interceptors' completion callbacks, then publish a failure event
        interceptorChain.applyAfterCompletion(request, e);
        publishChatFailedEvent(request, e, duration);
    }
    
    /**
     * Validate the request
     */
    protected void validateRequest(ChatRequest request) {
        if (request == null) {
            throw new IllegalArgumentException("ChatRequest cannot be null");
        }
        if (request.getMessages() == null || request.getMessages().isEmpty()) {
            throw new IllegalArgumentException("Messages cannot be empty");
        }
    }
    
    // Additional hook methods...
    protected abstract void publishChatCompletedEvent(ChatRequest request, 
                                                     ChatResponse response, 
                                                     long duration);
    
    protected abstract void publishChatFailedEvent(ChatRequest request, 
                                                  Exception e, 
                                                  long duration);
}
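The control flow of `AbstractAiClient` can be reduced to a compact, runnable illustration: the base class fixes the pipeline order in a final template method, and subclasses plug in only the variable step. These are simplified stand-ins for the article's classes, not the classes themselves:

```java
// Template-method pattern in miniature: the final call() method fixes the
// order validate -> pre -> doCall -> post; subclasses can only fill in doCall.
public class TemplateMethodSketch {

    static abstract class PipelineClient {
        final StringBuilder trace = new StringBuilder(); // records execution order

        // The template method: subclasses cannot reorder or skip the steps
        final String call(String input) {
            if (input == null || input.isBlank()) {
                throw new IllegalArgumentException("input required");
            }
            trace.append("pre;");
            String result = doCall(input);   // the only step subclasses override
            trace.append("post;");
            return result;
        }

        protected abstract String doCall(String input);
    }

    static class EchoClient extends PipelineClient {
        protected String doCall(String input) {
            trace.append("doCall;");
            return "echo:" + input;
        }
    }

    public static void main(String[] args) {
        EchoClient client = new EchoClient();
        System.out.println(client.call("hi"));  // echo:hi
        System.out.println(client.trace);       // pre;doCall;post;
    }
}
```

Declaring the template method `final` is what guarantees the invariant parts of the pipeline (validation, interceptors, metrics) always run, no matter which provider subclass is in play.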

3. Factory Pattern - AI Client Factory

// AI client factory
package com.example.ai.pattern.factory;

import com.example.ai.core.AiClient;
import com.example.ai.config.AiProperties;
import com.example.ai.pattern.strategy.AiProviderStrategy;
import com.example.ai.pattern.strategy.OpenAiStrategy;
import com.example.ai.pattern.strategy.AzureAiStrategy;
import com.example.ai.pattern.strategy.LocalAiStrategy;
import com.example.ai.provider.OpenAiClient;
import com.example.ai.provider.AzureAiClient;
import com.example.ai.provider.LocalAiClient;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.client.RestTemplate;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * AI client factory.
 * Creates different kinds of AI clients based on configuration.
 */
@Slf4j
public class AiClientFactory {
    
    // ConcurrentHashMap keeps computeIfAbsent safe when clients are created concurrently
    private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
    private final AiProperties aiProperties;
    private final RestTemplate restTemplate;
    private final ObjectMapper objectMapper;
    
    public AiClientFactory(AiProperties aiProperties, 
                          RestTemplate restTemplate, 
                          ObjectMapper objectMapper) {
        this.aiProperties = aiProperties;
        this.restTemplate = restTemplate;
        this.objectMapper = objectMapper;
    }
    
    /**
     * Create an AI client for the given provider
     */
    public AiClient createAiClient(String provider) {
        return clientCache.computeIfAbsent(provider, this::createNewAiClient);
    }
    
    /**
     * Select a client automatically based on the model name
     */
    public AiClient createAiClientForModel(String model) {
        String provider = determineProviderByModel(model);
        return createAiClient(provider);
    }
    
    private AiClient createNewAiClient(String provider) {
        log.info("Creating new AI client for provider: {}", provider);
        
        switch (provider.toLowerCase()) {
            case "openai":
                return createOpenAiClient();
            case "azure":
                return createAzureAiClient();
            case "local":
                return createLocalAiClient();
            default:
                throw new IllegalArgumentException("Unsupported AI provider: " + provider);
        }
    }
    
    private AiClient createOpenAiClient() {
        var properties = aiProperties.getOpenai();
        var strategy = new OpenAiStrategy(
            restTemplate,
            objectMapper,
            properties.getApiKey(),
            properties.getBaseUrl()
        );
        
        return new OpenAiClient(strategy, aiProperties);
    }
    
    private AiClient createAzureAiClient() {
        var properties = aiProperties.getAzure();
        var strategy = new AzureAiStrategy(
            restTemplate,
            objectMapper,
            properties.getApiKey(),
            properties.getEndpoint(),
            properties.getDeploymentName(),
            properties.getApiVersion()
        );
        
        return new AzureAiClient(strategy, aiProperties);
    }
    
    private AiClient createLocalAiClient() {
        var properties = aiProperties.getLocal();
        var strategy = new LocalAiStrategy(
            properties.getModelPath(),
            properties.getDevice()
        );
        
        return new LocalAiClient(strategy, aiProperties);
    }
    
    private String determineProviderByModel(String model) {
        if (model == null) {
            return aiProperties.getDefaultProvider();
        }
        
        if (model.startsWith("gpt-") || model.startsWith("text-")) {
            return "openai";
        } else if (model.contains("azure")) {
            return "azure";
        } else if (model.contains("local") || model.contains("llama")) {
            return "local";
        }
        
        return aiProperties.getDefaultProvider();
    }
    
    /**
     * Get all enabled provider strategies
     */
    public Map<String, AiProviderStrategy> getSupportedStrategies() {
        Map<String, AiProviderStrategy> strategies = new HashMap<>();
        
        if (aiProperties.getOpenai().isEnabled()) {
            strategies.put("openai", new OpenAiStrategy(
                restTemplate, objectMapper,
                aiProperties.getOpenai().getApiKey(),
                aiProperties.getOpenai().getBaseUrl()
            ));
        }
        
        if (aiProperties.getAzure().isEnabled()) {
            strategies.put("azure", new AzureAiStrategy(
                restTemplate, objectMapper,
                aiProperties.getAzure().getApiKey(),
                aiProperties.getAzure().getEndpoint(),
                aiProperties.getAzure().getDeploymentName(),
                aiProperties.getAzure().getApiVersion()
            ));
        }
        
        if (aiProperties.getLocal().isEnabled()) {
            strategies.put("local", new LocalAiStrategy(
                aiProperties.getLocal().getModelPath(),
                aiProperties.getLocal().getDevice()
            ));
        }
        
        return strategies;
    }
}
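The caching behaviour of `createAiClient` - build once per provider via `computeIfAbsent`, reuse afterwards - can be demonstrated in isolation. A minimal sketch with string placeholders standing in for real clients:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Caching factory in miniature: the mapping function runs only on a cache
// miss, so each provider's client is constructed exactly once.
public class CachingFactorySketch {

    static final AtomicInteger CREATED = new AtomicInteger(); // counts real constructions
    static final Map<String, String> CACHE = new ConcurrentHashMap<>();

    static String createClient(String provider) {
        return CACHE.computeIfAbsent(provider, p -> {
            CREATED.incrementAndGet();
            switch (p) {
                case "openai":
                case "azure":
                case "local":
                    return "client:" + p;   // placeholder for an expensive client object
                default:
                    // throwing inside the mapping function leaves no cache entry behind
                    throw new IllegalArgumentException("Unsupported provider: " + p);
            }
        });
    }

    public static void main(String[] args) {
        createClient("openai");
        createClient("openai"); // served from the cache, no second construction
        System.out.println(CREATED.get()); // 1
    }
}
```

Note that an unsupported provider fails without polluting the cache, so a later retry with a valid name still works.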

4. Adapter Pattern - Model Adapters

// Adapter interface
package com.example.ai.pattern.adapter;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;

/**
 * AI model adapter interface.
 * Converts the unified request into a model-specific format.
 */
public interface AiModelAdapter {
    
    String getModelType();
    
    boolean supports(String model);
    
    Object adaptRequest(ChatRequest request);
    
    ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest);
}

// OpenAI adapter implementation
package com.example.ai.pattern.adapter;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Slf4j
@RequiredArgsConstructor
public class OpenAiAdapter implements AiModelAdapter {
    
    private final ObjectMapper objectMapper;
    
    @Override
    public String getModelType() {
        return "openai";
    }
    
    @Override
    public boolean supports(String model) {
        return model != null && 
               (model.startsWith("gpt-") || 
                model.startsWith("text-"));
    }
    
    @Override
    public Object adaptRequest(ChatRequest request) {
        Map<String, Object> openAiRequest = new HashMap<>();
        
        // Map the model name
        openAiRequest.put("model", request.getModel());
        
        // Convert the messages
        List<Map<String, String>> messages = request.getMessages().stream()
            .map(this::convertMessage)
            .collect(Collectors.toList());
        openAiRequest.put("messages", messages);
        
        // Convert optional parameters
        if (request.getTemperature() != null) {
            openAiRequest.put("temperature", request.getTemperature());
        }
        if (request.getMaxTokens() != null) {
            openAiRequest.put("max_tokens", request.getMaxTokens());
        }
        if (request.getTopP() != null) {
            openAiRequest.put("top_p", request.getTopP());
        }
        
        // Streaming flag
        if (request.isStream()) {
            openAiRequest.put("stream", true);
        }
        
        log.debug("Adapted OpenAI request: {}", openAiRequest);
        return openAiRequest;
    }
    
    @Override
    public ChatResponse adaptResponse(Object rawResponse, ChatRequest originalRequest) {
        try {
            String responseJson = objectMapper.writeValueAsString(rawResponse);
            Map<String, Object> responseMap = objectMapper.readValue(responseJson, Map.class);
            
            return ChatResponse.builder()
                .id((String) responseMap.get("id"))
                .model((String) responseMap.get("model"))
                .created((Integer) responseMap.get("created"))
                .choices(extractChoices(responseMap))
                .usage(extractUsage(responseMap))
                .build();
                
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Failed to adapt OpenAI response", e);
        }
    }
    
    private Map<String, String> convertMessage(Message message) {
        Map<String, String> result = new HashMap<>();
        result.put("role", message.getRole().name().toLowerCase());
        result.put("content", message.getContent());
        return result;
    }
    
    private List<ChatResponse.Choice> extractChoices(Map<String, Object> responseMap) {
        // Extract the choices list (omitted for brevity)
        return null;
    }
    
    private ChatResponse.Usage extractUsage(Map<String, Object> responseMap) {
        // Extract the usage block (omitted for brevity)
        return null;
    }
}
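The essence of the adapter is the field mapping that only this class knows about: the unified request uses camelCase Java fields, while the wire format expects snake_case keys and omits unset values. A self-contained miniature of that idea (a simplified stand-in, not the article's `OpenAiAdapter`):

```java
import java.util.HashMap;
import java.util.Map;

// Adapter pattern in miniature: translate the framework's unified request
// shape into a provider-specific wire format in one dedicated place.
public class AdapterSketch {

    // Stand-in for the unified ChatRequest
    record UnifiedRequest(String model, Integer maxTokens, Double topP) {}

    static Map<String, Object> adapt(UnifiedRequest request) {
        Map<String, Object> wire = new HashMap<>();
        wire.put("model", request.model());
        if (request.maxTokens() != null) {
            wire.put("max_tokens", request.maxTokens()); // renamed for the wire format
        }
        if (request.topP() != null) {
            wire.put("top_p", request.topP());
        }
        return wire; // only set fields appear in the outgoing payload
    }

    public static void main(String[] args) {
        Map<String, Object> wire = adapt(new UnifiedRequest("gpt-4", 256, null));
        System.out.println(wire.containsKey("max_tokens")); // true
        System.out.println(wire.containsKey("top_p"));      // false: null fields are omitted
    }
}
```

Because every rename lives in the adapter, adding a provider with different key names (say, Azure's deployment-based URLs) never touches the unified model classes.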

5. Chain of Responsibility - Interceptor Chain

// Interceptor interface
package com.example.ai.pattern.chain;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;

/**
 * AI client interceptor
 */
public interface AiClientInterceptor {
    
    /**
     * Pre-handle: return false to veto the request
     */
    default boolean preHandle(ChatRequest request) {
        return true;
    }
    
    /**
     * Post-handle, invoked after a successful call
     */
    default void postHandle(ChatRequest request, ChatResponse response) {
    }
    
    /**
     * Completion callback; ex is non-null when the request failed
     */
    default void afterCompletion(ChatRequest request, Exception ex) {
    }
}

// Interceptor chain
package com.example.ai.pattern.chain;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

@Slf4j
@RequiredArgsConstructor
public class AiClientInterceptorChain {
    
    private final List<AiClientInterceptor> interceptors;
    
    /**
     * Apply the pre-handle phase of every interceptor
     */
    public boolean applyPreHandle(ChatRequest request) {
        for (AiClientInterceptor interceptor : interceptors) {
            if (!interceptor.preHandle(request)) {
                log.debug("Interceptor {} prevented request processing", 
                         interceptor.getClass().getSimpleName());
                return false;
            }
        }
        return true;
    }
    
    /**
     * Apply the post-handle phase of every interceptor
     */
    public void applyPostHandle(ChatRequest request, ChatResponse response) {
        for (AiClientInterceptor interceptor : interceptors) {
            try {
                interceptor.postHandle(request, response);
            } catch (Exception e) {
                log.error("Interceptor postHandle failed", e);
            }
        }
    }
    
    /**
     * Apply the completion phase of every interceptor
     */
    public void applyAfterCompletion(ChatRequest request, Exception ex) {
        for (AiClientInterceptor interceptor : interceptors) {
            try {
                interceptor.afterCompletion(request, ex);
            } catch (Exception e) {
                log.error("Interceptor afterCompletion failed", e);
            }
        }
    }
}
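The key contract of the chain is that pre-handle short-circuits at the first veto: later interceptors never see a rejected request. A minimal runnable sketch of just that behaviour:

```java
import java.util.List;

// Chain-of-responsibility in miniature: walk the interceptors in order and
// stop at the first one that returns false.
public class InterceptorChainSketch {

    interface Interceptor {
        boolean preHandle(String request);
    }

    static boolean applyPreHandle(List<Interceptor> chain, String request) {
        for (Interceptor interceptor : chain) {
            if (!interceptor.preHandle(request)) {
                return false; // short-circuit: remaining interceptors never run
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Interceptor allowAll = req -> true;              // e.g. a logging interceptor
        Interceptor denyEmpty = req -> !req.isBlank();   // e.g. a validation interceptor

        System.out.println(applyPreHandle(List.of(allowAll, denyEmpty), "hello")); // true
        System.out.println(applyPreHandle(List.of(allowAll, denyEmpty), ""));      // false
    }
}
```

This is why interceptor ordering matters in the real chain: cheap, likely-to-veto checks (rate limiting, validation) belong before expensive ones.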

6. Builder Pattern - Request Builder

// Chat request builder
package com.example.ai.pattern.builder;

import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.Message;

import java.util.ArrayList;
import java.util.List;

/**
 * Chat request builder.
 * Uses the builder pattern to assemble a complex request object.
 */
public class ChatRequestBuilder {
    
    private String model;
    private final List<Message> messages = new ArrayList<>();
    private Double temperature;
    private Integer maxTokens;
    private Double topP;
    private boolean stream = false;
    
    public ChatRequestBuilder model(String model) {
        this.model = model;
        return this;
    }
    
    public ChatRequestBuilder message(Message message) {
        this.messages.add(message);
        return this;
    }
    
    public ChatRequestBuilder message(String role, String content) {
        return message(Message.builder()
            .role(Message.Role.valueOf(role.toUpperCase()))
            .content(content)
            .build());
    }
    
    public ChatRequestBuilder systemMessage(String content) {
        return message("system", content);
    }
    
    public ChatRequestBuilder userMessage(String content) {
        return message("user", content);
    }
    
    public ChatRequestBuilder assistantMessage(String content) {
        return message("assistant", content);
    }
    
    public ChatRequestBuilder temperature(Double temperature) {
        this.temperature = temperature;
        return this;
    }
    
    public ChatRequestBuilder maxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
        return this;
    }
    
    public ChatRequestBuilder topP(Double topP) {
        this.topP = topP;
        return this;
    }
    
    public ChatRequestBuilder stream(boolean stream) {
        this.stream = stream;
        return this;
    }
    
    public ChatRequest build() {
        if (model == null || model.trim().isEmpty()) {
            throw new IllegalArgumentException("Model is required");
        }
        if (messages.isEmpty()) {
            throw new IllegalArgumentException("At least one message is required");
        }
        
        return ChatRequest.builder()
            .model(model)
            .messages(new ArrayList<>(messages))
            .temperature(temperature)
            .maxTokens(maxTokens)
            .topP(topP)
            .stream(stream)
            .build();
    }
}
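A self-contained miniature of the fluent API above shows the two things the builder buys us: readable optional parameters and a single validation point in `build()`. The types here are simplified stand-ins, not the article's `ChatRequest`:

```java
import java.util.ArrayList;
import java.util.List;

// Builder pattern in miniature: optional parameters accumulate through a
// fluent API, and build() is the one place that checks the finished object.
public class BuilderSketch {

    // Immutable result type (stand-in for ChatRequest)
    record Request(String model, List<String> messages, Double temperature) {}

    static class RequestBuilder {
        private String model;
        private final List<String> messages = new ArrayList<>();
        private Double temperature;

        RequestBuilder model(String model) { this.model = model; return this; }
        RequestBuilder userMessage(String content) { this.messages.add("user:" + content); return this; }
        RequestBuilder temperature(Double t) { this.temperature = t; return this; }

        Request build() {
            // Validation lives here, so no partially-constructed request escapes
            if (model == null) throw new IllegalArgumentException("Model is required");
            if (messages.isEmpty()) throw new IllegalArgumentException("At least one message is required");
            return new Request(model, List.copyOf(messages), temperature);
        }
    }

    public static void main(String[] args) {
        Request request = new RequestBuilder()
                .model("gpt-3.5-turbo")
                .userMessage("Hello")
                .temperature(0.7)
                .build();
        System.out.println(request.model());           // gpt-3.5-turbo
        System.out.println(request.messages().size()); // 1
    }
}
```

The defensive `List.copyOf` in `build()` mirrors the `new ArrayList<>(messages)` copy in the article's builder: the built request stays valid even if the builder is reused afterwards.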

🔄 VI. Model Router

package com.example.ai.router;

import com.example.ai.core.AiClient;
import com.example.ai.pattern.factory.AiClientFactory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Model router.
 * Routes a request to the right AI client based on the model name.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class ModelRouter {
    
    private final AiClientFactory aiClientFactory;
    private final Map<String, AiClient> clientCache = new ConcurrentHashMap<>();
    
    /**
     * Route to the matching AI client
     */
    public AiClient route(String model) {
        return clientCache.computeIfAbsent(model, this::createClientForModel);
    }
    
    /**
     * List all available models
     */
    public Map<String, String> getAvailableModels() {
        Map<String, String> models = new ConcurrentHashMap<>();
        
        // OpenAI models
        models.put("gpt-3.5-turbo", "openai");
        models.put("gpt-4", "openai");
        models.put("text-embedding-ada-002", "openai");
        
        // Azure models
        models.put("gpt-35-turbo", "azure");
        models.put("gpt-4-azure", "azure");
        
        // Local models
        models.put("llama-2-7b", "local");
        models.put("vicuna-13b", "local");
        
        return models;
    }
    
    /**
     * Select the best provider for a model, honoring an explicit preference.
     */
    public String selectBestProvider(String model, String preferredProvider) {
        if (preferredProvider != null) {
            return preferredProvider;
        }
        
        var availableModels = getAvailableModels();
        return availableModels.getOrDefault(model, "openai");
    }
    
    private AiClient createClientForModel(String model) {
        String provider = determineProvider(model);
        log.info("Creating AI client for model: {} -> provider: {}", model, provider);
        
        return aiClientFactory.createAiClient(provider);
    }
    
    private String determineProvider(String model) {
        if (model == null) {
            return "openai";
        }
        
        // Consult the explicit registry first; name heuristics alone would
        // misroute e.g. "gpt-35-turbo" (Azure) and "vicuna-13b" (local).
        String registered = getAvailableModels().get(model);
        if (registered != null) {
            return registered;
        }
        
        if (model.startsWith("gpt-") || model.startsWith("text-")) {
            return "openai";
        } else if (model.toLowerCase().contains("azure")) {
            return "azure";
        } else if (model.contains("local") || model.contains("llama") || model.contains("vicuna")) {
            return "local";
        }
        
        return "openai";
    }
}
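The caching and heuristic behavior of the router can be exercised without Spring. The following is a minimal, self-contained stand-in (class and method names here are illustrative, not part of the project): repeated lookups for the same model hit the `ConcurrentHashMap` cache, so the factory is invoked only once per distinct model name.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class RouterSketch {
    static final AtomicInteger created = new AtomicInteger();
    static final Map<String, String> cache = new ConcurrentHashMap<>();

    static String route(String model) {
        return cache.computeIfAbsent(model, m -> {
            created.incrementAndGet();              // one factory call per distinct model
            return determineProvider(m) + "-client"; // stand-in for a real AiClient
        });
    }

    static String determineProvider(String model) {
        if (model.startsWith("gpt-") || model.startsWith("text-")) return "openai";
        if (model.toLowerCase().contains("azure")) return "azure";
        if (model.contains("llama") || model.contains("vicuna")) return "local";
        return "openai";
    }

    public static void main(String[] args) {
        System.out.println(route("gpt-4"));        // openai-client
        System.out.println(route("gpt-4"));        // served from the cache
        System.out.println(route("llama-2-7b"));   // local-client
        System.out.println("clients created: " + created.get()); // 2
    }
}
```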

📊 VII. Metrics Collection

package com.example.ai.metrics;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * AI client monitoring metrics.
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class AiClientMetrics {
    
    private final MeterRegistry meterRegistry;
    private final Map<String, Timer> timers = new ConcurrentHashMap<>();
    private final Map<String, Counter> counters = new ConcurrentHashMap<>();
    
    /**
     * Record a successful request.
     */
    public void recordSuccess(String provider, String model, String operation, long duration) {
        getTimer(provider, model, operation)
            .record(duration, TimeUnit.MILLISECONDS);
        
        incrementCounter("ai.client.requests.success", 
            Map.of("provider", provider, "model", model, "operation", operation));
        
        // Record the response-time distribution
        meterRegistry.timer("ai.client.response.time", 
            "provider", provider, 
            "model", model, 
            "operation", operation)
            .record(duration, TimeUnit.MILLISECONDS);
    }
    
    /**
     * Record a failed request.
     */
    public void recordError(String provider, String model, String operation, String errorType) {
        incrementCounter("ai.client.requests.error", 
            Map.of("provider", provider, 
                   "model", model, 
                   "operation", operation, 
                   "error_type", errorType));
    }
    
    /**
     * Record token usage.
     */
    public void recordTokenUsage(String provider, String model, 
                                int promptTokens, int completionTokens, int totalTokens) {
        meterRegistry.counter("ai.client.tokens.prompt", 
            "provider", provider, "model", model)
            .increment(promptTokens);
        
        meterRegistry.counter("ai.client.tokens.completion", 
            "provider", provider, "model", model)
            .increment(completionTokens);
        
        meterRegistry.counter("ai.client.tokens.total", 
            "provider", provider, "model", model)
            .increment(totalTokens);
    }
    
    /**
     * Record a streaming response chunk.
     */
    public void recordStreamChunk(String provider, String model) {
        incrementCounter("ai.client.stream.chunks", 
            Map.of("provider", provider, "model", model));
    }
    
    private Timer getTimer(String provider, String model, String operation) {
        String key = provider + ":" + model + ":" + operation;
        
        return timers.computeIfAbsent(key, k -> 
            Timer.builder("ai.client.requests")
                .tag("provider", provider)
                .tag("model", model)
                .tag("operation", operation)
                .publishPercentiles(0.5, 0.95, 0.99)
                .serviceLevelObjectives(java.time.Duration.ofMillis(100), 
                    java.time.Duration.ofMillis(500), 
                    java.time.Duration.ofMillis(1000))
                .register(meterRegistry)
        );
    }
    
    private void incrementCounter(String name, Map<String, String> tags) {
        Counter.Builder builder = Counter.builder(name);
        tags.forEach(builder::tag);
        builder.register(meterRegistry).increment();
    }
}
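The core idea of `getTimer`/`incrementCounter`, caching one meter per distinct tag combination, can be shown without Micrometer. This plain-Java stand-in (all names illustrative) keys a thread-safe counter on the (name, provider, model, operation) tuple, much as the `timers` and `counters` maps above do:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;

public class MetricsSketch {
    // One LongAdder per distinct tag combination; LongAdder keeps
    // concurrent increments cheap under contention.
    private final Map<String, LongAdder> counters = new ConcurrentHashMap<>();

    private static String key(String name, String provider, String model, String op) {
        return name + "{provider=" + provider + ",model=" + model + ",operation=" + op + "}";
    }

    public void increment(String name, String provider, String model, String op) {
        counters.computeIfAbsent(key(name, provider, model, op), k -> new LongAdder())
                .increment();
    }

    public long value(String name, String provider, String model, String op) {
        LongAdder a = counters.get(key(name, provider, model, op));
        return a == null ? 0 : a.sum();
    }

    public static void main(String[] args) {
        MetricsSketch m = new MetricsSketch();
        m.increment("ai.client.requests.success", "openai", "gpt-4", "chat");
        m.increment("ai.client.requests.success", "openai", "gpt-4", "chat");
        m.increment("ai.client.requests.error", "openai", "gpt-4", "chat");
        System.out.println(m.value("ai.client.requests.success", "openai", "gpt-4", "chat")); // 2
    }
}
```

Micrometer performs this registry-side deduplication itself; the local `timers` map in `AiClientMetrics` merely avoids rebuilding the `Timer.Builder` on every call.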

🎯 VIII. Usage Examples

1. Chat Service Example

package com.example.ai.service;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Message;
import com.example.ai.router.ModelRouter;
import com.example.ai.metrics.AiClientMetrics;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.concurrent.CompletableFuture;

@Slf4j
@Service
@RequiredArgsConstructor
public class ChatService {
    
    private final ModelRouter modelRouter;
    private final AiClientMetrics metrics;
    
    /**
     * Synchronous chat.
     */
    public String chat(String model, String userMessage) {
        long startTime = System.currentTimeMillis();
        // Resolve the provider from the router rather than hard-coding "openai",
        // so Azure and local models are tagged correctly in the metrics.
        String provider = modelRouter.selectBestProvider(model, null);
        
        try {
            // Get the appropriate AI client
            AiClient aiClient = modelRouter.route(model);
            
            // Build the request
            ChatRequest request = ChatRequest.builder()
                .model(model)
                .messages(List.of(
                    Message.userMessage(userMessage)
                ))
                .temperature(0.7)
                .maxTokens(1000)
                .build();
            
            // Execute the chat call
            ChatResponse response = aiClient.chat(request);
            String content = response.getContent();
            
            // Record metrics
            long duration = System.currentTimeMillis() - startTime;
            metrics.recordSuccess(provider, model, "chat", duration);
            metrics.recordTokenUsage(
                provider, model,
                response.getUsage().getPromptTokens(),
                response.getUsage().getCompletionTokens(),
                response.getUsage().getTotalTokens()
            );
            
            return content;
            
        } catch (Exception e) {
            metrics.recordError(provider, model, "chat", e.getClass().getSimpleName());
            throw e;
        }
    }
    
    /**
     * Asynchronous chat.
     */
    public CompletableFuture<String> chatAsync(String model, String userMessage) {
        return CompletableFuture.supplyAsync(() -> chat(model, userMessage));
    }
    
    /**
     * Chat with conversation context.
     */
    public String chatWithContext(String model, String systemPrompt, 
                                 List<Message> history, String userMessage) {
        AiClient aiClient = modelRouter.route(model);
        
        // Build the message list: system prompt, then history, then the new message
        List<Message> messages = new java.util.ArrayList<>();
        messages.add(Message.systemMessage(systemPrompt));
        messages.addAll(history);
        messages.add(Message.userMessage(userMessage));
        
        ChatRequest request = ChatRequest.builder()
            .model(model)
            .messages(messages)
            .temperature(0.7)
            .maxTokens(2000)
            .build();
        
        return aiClient.chat(request).getContent();
    }
}

2. REST Controller

package com.example.ai.controller;

import com.example.ai.service.ChatService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

@Slf4j
@RestController
@RequestMapping("/api/v1/ai")
@RequiredArgsConstructor
public class AiController {
    
    private final ChatService chatService;
    
    /**
     * Simple chat endpoint.
     */
    @PostMapping("/chat")
    public ResponseEntity<Map<String, Object>> chat(
            @RequestParam(defaultValue = "gpt-3.5-turbo") String model,
            @RequestBody Map<String, String> request) {
        
        String message = request.get("message");
        if (message == null || message.trim().isEmpty()) {
            return ResponseEntity.badRequest()
                .body(Map.of("error", "Message is required"));
        }
        
        try {
            String response = chatService.chat(model, message);
            return ResponseEntity.ok(Map.of(
                "model", model,
                "response", response
            ));
        } catch (Exception e) {
            log.error("Chat failed", e);
            return ResponseEntity.internalServerError()
                .body(Map.of("error", e.getMessage()));
        }
    }
    
    /**
     * Asynchronous chat endpoint.
     */
    @PostMapping("/chat/async")
    public CompletableFuture<ResponseEntity<Map<String, Object>>> chatAsync(
            @RequestParam(defaultValue = "gpt-3.5-turbo") String model,
            @RequestBody Map<String, String> request) {
        
        // Validate up front, mirroring the synchronous endpoint
        String message = request.get("message");
        if (message == null || message.trim().isEmpty()) {
            return CompletableFuture.completedFuture(ResponseEntity.badRequest()
                .body(Map.of("error", "Message is required")));
        }
        
        return chatService.chatAsync(model, message)
            .thenApply(response -> ResponseEntity.ok(Map.of(
                "model", model,
                "response", response
            )))
            .exceptionally(e -> ResponseEntity.internalServerError()
                .body(Map.of("error", e.getMessage())));
    }
    
    /**
     * Health check.
     */
    @GetMapping("/health")
    public ResponseEntity<Map<String, Object>> health() {
        return ResponseEntity.ok(Map.of(
            "status", "UP",
            "timestamp", System.currentTimeMillis()
        ));
    }
    
    /**
     * List supported models.
     */
    @GetMapping("/models")
    public ResponseEntity<Map<String, Object>> getAvailableModels() {
        return ResponseEntity.ok(Map.of(
            "models", Map.of(
                "gpt-3.5-turbo", "OpenAI GPT-3.5 Turbo",
                "gpt-4", "OpenAI GPT-4",
                "text-embedding-ada-002", "OpenAI Embedding Model"
            )
        ));
    }
}

🧪 IX. Test Code

package com.example.ai.service;

import com.example.ai.core.AiClient;
import com.example.ai.core.model.ChatRequest;
import com.example.ai.core.model.ChatResponse;
import com.example.ai.core.model.Usage;
import com.example.ai.metrics.AiClientMetrics;
import com.example.ai.router.ModelRouter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
class ChatServiceTest {
    
    @Mock
    private ModelRouter modelRouter;
    
    @Mock
    private AiClient aiClient;
    
    @Mock
    private AiClientMetrics metrics;
    
    private ChatService chatService;
    
    @BeforeEach
    void setUp() {
        // Passing null for the metrics dependency would NPE inside chat();
        // use a mock so the service can record metrics safely.
        chatService = new ChatService(modelRouter, metrics);
    }
    
    @Test
    void testChat_Success() {
        // Prepare test data
        String model = "gpt-3.5-turbo";
        String userMessage = "Hello, how are you?";
        String expectedResponse = "I'm fine, thank you!";
        
        // Stub collaborators
        when(modelRouter.route(model)).thenReturn(aiClient);
        
        // chat() reads token usage from the response, so the mock must include it
        ChatResponse mockResponse = ChatResponse.builder()
            .content(expectedResponse)
            .usage(com.example.ai.core.model.Usage.builder()
                .promptTokens(10)
                .completionTokens(8)
                .totalTokens(18)
                .build())
            .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);
        
        // Execute
        String result = chatService.chat(model, userMessage);
        
        // Verify
        assertEquals(expectedResponse, result);
        verify(modelRouter, times(1)).route(model);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }
    
    @Test
    void testChat_WithSystemPrompt() {
        // Prepare test data
        String model = "gpt-3.5-turbo";
        String systemPrompt = "You are a helpful assistant.";
        String userMessage = "What is AI?";
        
        // Stub collaborators
        when(modelRouter.route(model)).thenReturn(aiClient);
        
        ChatResponse mockResponse = ChatResponse.builder()
            .content("AI stands for Artificial Intelligence.")
            .build();
        when(aiClient.chat(any(ChatRequest.class))).thenReturn(mockResponse);
        
        // Execute
        String result = chatService.chatWithContext(
            model, systemPrompt, List.of(), userMessage
        );
        
        // Verify
        assertNotNull(result);
        verify(aiClient, times(1)).chat(any(ChatRequest.class));
    }
    
    @Test
    void testChat_ModelNotFound() {
        // Prepare test data
        String model = "unknown-model";
        String userMessage = "Hello";
        
        // Stub collaborators
        when(modelRouter.route(model)).thenThrow(
            new IllegalArgumentException("Model not found: " + model)
        );
        
        // Execute and expect the exception
        assertThrows(IllegalArgumentException.class, () -> {
            chatService.chat(model, userMessage);
        });
    }
}

📈 X. Design Pattern Composition Architecture

The original article renders a layered architecture diagram here; its labels are preserved below as a layer list:

- Client layer: REST controller
- Service layer: chat service
- Routing layer: model router and model router factory
- Core layer (design pattern composition):
  - Factory pattern: AiClientFactory
  - Strategy pattern: AiProviderStrategy
  - Template method pattern: AbstractAiClient
  - Adapter pattern: AiModelAdapter
  - Chain of responsibility pattern: AiClientInterceptorChain
  - Builder pattern: ChatRequestBuilder
- Provider layer: OpenAI client, Azure AI client, local model client
- Monitoring layer: metrics collector, interceptor chain, Micrometer integration
- Configuration layer: auto-configuration, property binding, conditional assembly

🔄 XI. Workflow Sequence

The sequence diagram in the original collapsed into a single line of text; the recoverable flow (typical end-to-end latency 200-800 ms) is:

1. The client sends POST /api/v1/ai/chat to the REST controller, which calls the chat service.
2. The chat service asks the model router for a suitable client; the router obtains it from the client factory, which selects the matching provider strategy and returns the client.
3. The service invokes chat() on the AI client; the interceptor chain runs its pre-handle phase.
4. The model adapter converts the request into the provider's format and calls the provider's API.
5. The adapter converts the raw response back into the unified format; the interceptor chain runs its post-handle phase and records monitoring metrics.
6. The result propagates back to the chat service, which records business metrics, and the controller returns the HTTP response.

🎯 XII. Running and Testing

1. Start the Application

# Set environment variables
export OPENAI_API_KEY=your-api-key-here

# Build and run
mvn clean package
java -jar target/spring-ai-simulation-1.0.0.jar

# Or run directly with Maven
mvn spring-boot:run

2. Test the API

# Health check
curl http://localhost:8080/api/v1/ai/health

# List available models
curl http://localhost:8080/api/v1/ai/models

# Chat test
curl -X POST http://localhost:8080/api/v1/ai/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "Hello, how are you?"}'

# Chat with a specific model
curl -X POST "http://localhost:8080/api/v1/ai/chat?model=gpt-3.5-turbo" \
  -H "Content-Type: application/json" \
  -d '{"message": "What is Spring AI?"}'

3. Monitoring Metrics

  • Prometheus metrics: http://localhost:8080/actuator/prometheus
  • Health check: http://localhost:8080/actuator/health
  • Application info: http://localhost:8080/actuator/info
  • Metric details: http://localhost:8080/actuator/metrics
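To scrape the Prometheus endpoint above, a minimal scrape job might look like the following (the job name and the default port 8080 are assumptions; adjust to your deployment):

```yaml
scrape_configs:
  - job_name: 'spring-ai-simulation'
    metrics_path: '/actuator/prometheus'
    static_configs:
      - targets: ['localhost:8080']
```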

📊 XIII. Design Pattern Summary

| Design Pattern | Where Applied | Problem Solved | Core Class |
| --- | --- | --- | --- |
| Factory | AI client creation | Uniform creation of clients for different AI providers | AiClientFactory |
| Strategy | AI provider switching | Supporting multiple AI service providers | AiProviderStrategy |
| Template Method | Request processing flow | Unifying the AI request handling pipeline | AbstractAiClient |
| Adapter | Model format conversion | Converting request/response formats between models | AiModelAdapter |
| Chain of Responsibility | Interceptor processing | Pluggable interceptor chain | AiClientInterceptorChain |
| Builder | Request object construction | Building complex request objects | ChatRequestBuilder |
| Singleton | Configuration management | Ensuring a single configuration instance | AiProperties |
| Proxy | Monitoring enhancement | Adding monitoring to AI clients | AiClientProxy |

🔧 XIV. Extension Guide

1. Add a New AI Provider

@Component
public class CustomAiStrategy implements AiProviderStrategy {
    
    @Override
    public String getProviderName() {
        return "custom";
    }
    
    @Override
    public ChatResponse chat(ChatRequest request) {
        // Implement the custom AI service call here
        return null;
    }
    
    // Other method implementations...
}
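How the factory resolves a strategy from `getProviderName()` can be sketched in plain Java. The registry approach below is an assumption about `AiClientFactory`'s internals, but it explains why a new `@Component` strategy is picked up without any factory changes: Spring injects all strategy beans as a list, and the factory indexes them by provider name.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class StrategyRegistrySketch {
    // Simplified stand-in for AiProviderStrategy
    interface AiProviderStrategy {
        String getProviderName();
        String chat(String prompt);
    }

    // Index strategies by provider name, as a factory constructor might
    // do with a Spring-injected List<AiProviderStrategy>.
    static Map<String, AiProviderStrategy> index(List<AiProviderStrategy> strategies) {
        return strategies.stream()
            .collect(Collectors.toMap(AiProviderStrategy::getProviderName, Function.identity()));
    }

    public static void main(String[] args) {
        AiProviderStrategy custom = new AiProviderStrategy() {
            public String getProviderName() { return "custom"; }
            public String chat(String prompt) { return "custom:" + prompt; }
        };
        Map<String, AiProviderStrategy> registry = index(List.of(custom));
        System.out.println(registry.get("custom").chat("hello")); // custom:hello
    }
}
```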

2. Add a Custom Interceptor

@Component
public class CustomInterceptor implements AiClientInterceptor {
    
    @Override
    public boolean preHandle(ChatRequest request) {
        // Custom pre-processing logic
        return true;
    }
    
    @Override
    public void postHandle(ChatRequest request, ChatResponse response) {
        // Custom post-processing logic
    }
}
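The chain semantics implied above, pre-handle in registration order with a `false` result short-circuiting the call, then post-handle after the response, can be shown in a self-contained sketch (the reverse post-handle order and all names here are illustrative assumptions, not the project's actual `AiClientInterceptorChain`):

```java
import java.util.List;

public class ChainSketch {
    interface Interceptor {
        default boolean preHandle(String request) { return true; }
        default void postHandle(String request, String response) {}
    }

    static String execute(List<Interceptor> chain, String request) {
        for (Interceptor i : chain) {
            if (!i.preHandle(request)) {
                return null;                 // a false pre-handle vetoes the call
            }
        }
        String response = "echo:" + request; // stand-in for the real provider call
        for (int i = chain.size() - 1; i >= 0; i--) {
            chain.get(i).postHandle(request, response); // post-handle in reverse order
        }
        return response;
    }

    public static void main(String[] args) {
        Interceptor logging = new Interceptor() {
            @Override public boolean preHandle(String r) {
                System.out.println("pre: " + r);
                return true;
            }
            @Override public void postHandle(String r, String resp) {
                System.out.println("post: " + resp);
            }
        };
        System.out.println(execute(List.of(logging), "hello"));
    }
}
```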

3. Configure a Custom Model

spring:
  ai:
    custom:
      enabled: true
      api-key: ${CUSTOM_API_KEY}
      endpoint: https://api.custom-ai.com/v1
      models:
        - custom-model-1
        - custom-model-2

This complete Spring AI simulation framework demonstrates how combining multiple design patterns produces a flexible, extensible, and maintainable AI service integration framework. Each pattern solves a specific problem; together they form a complete solution.
