Spring AI 源码深度分析
先接地气,简单了解一下。

🎯 一、Spring AI 源码架构概览
1. 源码仓库结构
spring-ai/
├── spring-ai-core/ # 核心模块
│ ├── src/main/java/org/springframework/ai/
│ │ ├── client/ # 客户端接口
│ │ ├── model/ # 模型定义
│ │ ├── prompt/ # 提示词
│ │ ├── vectorstore/ # 向量存储
│ │ └── util/ # 工具类
├── spring-ai-openai/ # OpenAI 适配器
├── spring-ai-azure-openai/ # Azure OpenAI 适配器
├── spring-ai-transformers/ # 本地模型
├── spring-ai-huggingface/ # Hugging Face
├── spring-ai-bom/ # 依赖管理
└── spring-ai-test/ # 测试支持
2. 核心类图关系(原图缺失)
🔧 二、核心流程源码分析
1. 自动配置机制
1.1 自动配置类
// spring-ai-openai/src/main/java/org/springframework/ai/openai/autoconfigure/OpenAiAutoConfiguration.java
/**
 * Auto-configuration that registers the OpenAI chat and embedding clients when
 * {@code OpenAiApi} is on the classpath.
 *
 * <p>Each client bean is created only if the application has not defined one itself
 * ({@code @ConditionalOnMissingBean}).
 */
@Configuration(proxyBeanMethods = false)
@ConditionalOnClass(OpenAiApi.class)
@EnableConfigurationProperties(OpenAiProperties.class)
@Import({ OpenAiConnectionDetails.class, OpenAiRestClientConfiguration.class })
public class OpenAiAutoConfiguration {

    /**
     * Chat client bean backed by its own {@link OpenAiApi}.
     * A retry template is attached only when retry options are configured.
     * NOTE(review): {@code createRetryTemplate} is not defined in this class as shown.
     */
    @Bean
    @ConditionalOnMissingBean
    public OpenAiChatClient openAiChatClient(
            OpenAiConnectionDetails connectionDetails,
            OpenAiProperties openAiProperties,
            List<RequestResponsePostProcessor> postProcessors,
            RestClient.Builder restClientBuilder) {
        OpenAiChatClient chatClient = new OpenAiChatClient(
                buildApi(connectionDetails, openAiProperties, restClientBuilder),
                openAiProperties.getOptions(),
                postProcessors);
        if (openAiProperties.getRetry() == null) {
            return chatClient;
        }
        chatClient.setRetryTemplate(createRetryTemplate(openAiProperties.getRetry()));
        return chatClient;
    }

    /** Embedding client bean backed by its own {@link OpenAiApi}. */
    @Bean
    @ConditionalOnMissingBean
    public OpenAiEmbeddingClient openAiEmbeddingClient(
            OpenAiConnectionDetails connectionDetails,
            OpenAiProperties openAiProperties,
            RestClient.Builder restClientBuilder) {
        OpenAiApi api = buildApi(connectionDetails, openAiProperties, restClientBuilder);
        return new OpenAiEmbeddingClient(api, openAiProperties.getEmbedding().getOptions());
    }

    /** Builds the low-level HTTP client from the connection base URL and the configured API key. */
    private static OpenAiApi buildApi(OpenAiConnectionDetails connectionDetails,
                                      OpenAiProperties openAiProperties,
                                      RestClient.Builder restClientBuilder) {
        return new OpenAiApi(connectionDetails.getBaseUrl(), openAiProperties.getApiKey(), restClientBuilder);
    }
}
1.2 属性配置类
// spring-ai-openai/src/main/java/org/springframework/ai/openai/autoconfigure/OpenAiProperties.java
/**
 * Externalized OpenAI settings bound from {@code spring.ai.openai.*}.
 *
 * <p>{@code @Data} generates the getters and setters for the private fields below. Spring
 * Boot's JavaBean binder needs the setters to populate them. The auto-configuration reads the
 * getters ({@code getApiKey()}, {@code getRetry()}, {@code getEmbedding()}, ...). Without
 * accessors, none of these properties could ever be bound or read.
 * NOTE(review): the auto-configuration also calls {@code getOptions()}, which this class does
 * not define; presumably {@code getChat()} was intended — confirm.
 */
@Data
@ConfigurationProperties(prefix = OpenAiProperties.CONFIG_PREFIX)
public class OpenAiProperties {

    /** Prefix under which all OpenAI properties live. */
    public static final String CONFIG_PREFIX = "spring.ai.openai";

    /** Root URL of the OpenAI REST API. */
    private String baseUrl = "https://api.openai.com/v1";

    /** API key; no default, must be configured. */
    private String apiKey;

    private ChatOptions chat = new ChatOptions();
    private EmbeddingOptions embedding = new EmbeddingOptions();
    private ImageOptions image = new ImageOptions();

    /** Request timeout; a bare number in configuration is interpreted as seconds. */
    @DurationUnit(ChronoUnit.SECONDS)
    private Duration timeout = Duration.ofSeconds(60);

    /** Optional retry settings; {@code null} when not configured. */
    private RetryOptions retry;

    /** Default chat-completion parameters; {@code null} means "not sent / provider default". */
    @Data
    public static class ChatOptions {
        private String model = "gpt-3.5-turbo";
        private Double temperature = 0.7;
        private Double topP = 1.0;
        private Integer maxTokens = null;
        private List<String> stop = null;
        private Double presencePenalty = 0.0;
        private Double frequencyPenalty = 0.0;
    }
}
2. 核心接口设计
2.1 ChatClient 接口
// spring-ai-core/src/main/java/org/springframework/ai/client/ChatClient.java
/**
 * Provider-neutral entry point for chat model calls.
 *
 * <p>Implementations must supply {@link #call(Prompt)}; every other method is a convenience
 * built on top of it.
 */
public interface ChatClient {

    /**
     * Sends one user message and returns only the text of the response.
     *
     * @param message user message text
     * @return the content of {@code getResult().getOutput()} of the response
     */
    default String call(String message) {
        ChatResponse response = call(new Prompt(new UserMessage(message)));
        return response.getResult().getOutput().getContent();
    }

    /**
     * Core call: sends the prompt to the model and returns the full response.
     *
     * @param prompt messages plus optional options
     * @return the model response
     */
    ChatResponse call(Prompt prompt);

    /**
     * Streams the response for a prompt.
     *
     * <p>The default throws {@link UnsupportedOperationException} directly from this method
     * (no Flux is returned). Streaming-capable clients override it.
     */
    default Flux<ChatResponse> stream(Prompt prompt) {
        throw new UnsupportedOperationException("Streaming not supported");
    }

    /**
     * Calls the model with the given options.
     *
     * <p>NOTE(review): the prompt always carries an empty ("") user message, so this overload
     * is only meaningful for providers that accept an empty message.
     */
    default ChatResponse call(ChatOptions options) {
        UserMessage emptyMessage = new UserMessage("");
        return call(new Prompt(emptyMessage, options));
    }
}
2.2 EmbeddingClient 接口
// spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java
/**
 * Turns text into embedding vectors.
 *
 * <p>Implementations supply {@link #call(EmbeddingRequest)}; the remaining methods are
 * conveniences that build a request with {@code null} options and unwrap the result.
 */
public interface EmbeddingClient {

    /**
     * Embeds a single text.
     *
     * @return the output vector of the first result
     */
    default List<Double> embed(String text) {
        List<String> singleText = List.of(text);
        EmbeddingResponse response = call(new EmbeddingRequest(singleText, null));
        return response.getResults().get(0).getOutput();
    }

    /**
     * Embeds several texts in one request.
     *
     * @return one vector per result, in result order
     */
    default List<List<Double>> embed(List<String> texts) {
        EmbeddingRequest request = new EmbeddingRequest(texts, null);
        return call(request).getResults()
                .stream()
                .map(Embedding::getOutput)
                .collect(Collectors.toList());
    }

    /** Core embedding call; must be implemented. */
    EmbeddingResponse call(EmbeddingRequest request);

    /**
     * Vector dimensionality, measured by embedding the probe text "test".
     *
     * <p>Note: this issues a real embedding request on every invocation; cache the value if it
     * is needed repeatedly.
     */
    default int dimensions() {
        List<Double> probe = embed("test");
        return probe.size();
    }
}
3. OpenAI 适配器实现
3.1 OpenAiChatClient 实现
// spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiChatClient.java
public class OpenAiChatClient implements ChatClient, ChatOptionsCapable, Retryable {
private final OpenAiApi openAiApi;
private final OpenAiChatOptions defaultOptions;
private final List<RequestResponsePostProcessor> postProcessors;
private RetryTemplate retryTemplate = new RetryTemplate();
public OpenAiChatClient(OpenAiApi openAiApi,
OpenAiChatOptions options,
List<RequestResponsePostProcessor> postProcessors) {
this.openAiApi = openAiApi;
this.defaultOptions = options;
this.postProcessors = postProcessors != null ? postProcessors : List.of();
}
@Override
public ChatResponse call(Prompt prompt) {
return retryTemplate.execute(context -> {
// 1. 构建请求
ChatCompletionRequest request = createRequest(prompt);
// 2. 执行请求
ChatCompletion completion = openAiApi.chatCompletion(request);
// 3. 构建响应
ChatResponse response = convertCompletionToResponse(completion, prompt);
// 4. 后处理
return postProcessResponse(response, prompt);
});
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return Flux.create(sink -> {
try {
// 1. 构建流式请求
ChatCompletionRequest request = createRequest(prompt);
request.setStream(true);
// 2. 执行流式请求
Flux<ChatCompletionChunk> flux = openAiApi.chatCompletionStream(request);
// 3. 处理流式响应
flux.subscribe(
chunk -> {
ChatResponse response = convertChunkToResponse(chunk, prompt);
sink.next(response);
},
sink::error,
sink::complete
);
} catch (Exception e) {
sink.error(e);
}
});
}
private ChatCompletionRequest createRequest(Prompt prompt) {
// 合并默认选项和提示词选项
OpenAiChatOptions mergedOptions = mergeOptions(prompt);
// 转换消息格式
List<ChatCompletionMessage> messages = convertMessages(prompt.getMessages());
return ChatCompletionRequest.builder()
.model(mergedOptions.getModel())
.messages(messages)
.temperature(mergedOptions.getTemperature())
.topP(mergedOptions.getTopP())
.maxTokens(mergedOptions.getMaxTokens())
.stop(mergedOptions.getStop())
.presencePenalty(mergedOptions.getPresencePenalty())
.frequencyPenalty(mergedOptions.getFrequencyPenalty())
.build();
}
private ChatResponse convertCompletionToResponse(ChatCompletion completion, Prompt prompt) {
List<Generation> generations = completion.getChoices().stream()
.map(choice -> new Generation(choice.getMessage().getContent(), extractMetadata(choice)))
.collect(Collectors.toList());
Map<String, Object> metadata = new HashMap<>();
metadata.put("id", completion.getId());
metadata.put("model", completion.getModel());
metadata.put("created", completion.getCreated());
if (completion.getUsage() != null) {
metadata.put("usage", completion.getUsage());
}
return new ChatResponse(generations, metadata);
}
private ChatResponse postProcessResponse(ChatResponse response, Prompt prompt) {
ChatResponse processed = response;
for (RequestResponsePostProcessor processor : postProcessors) {
processed = processor.postProcess(processed, prompt);
}
return processed;
}
}
3.2 OpenAiApi HTTP 客户端
// spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
public class OpenAiApi {
private final RestClient restClient;
private final String baseUrl;
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder) {
this.baseUrl = baseUrl;
this.restClient = restClientBuilder
.baseUrl(baseUrl)
.defaultHeader("Authorization", "Bearer " + apiKey)
.defaultHeader("Content-Type", "application/json")
.requestInterceptor(new RetryRequestInterceptor())
.build();
}
public ChatCompletion chatCompletion(ChatCompletionRequest request) {
return restClient.post()
.uri("/chat/completions")
.body(request)
.retrieve()
.onStatus(status -> status.is4xxClientError() || status.is5xxServerError(),
(req, res) -> handleError(req, res, request))
.body(ChatCompletion.class);
}
public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest request) {
return restClient.post()
.uri("/chat/completions")
.body(request)
.retrieve()
.body(FluxTokenExtractor.class)
.flatMap(this::parseChunk);
}
private Mono<ChatCompletionChunk> parseChunk(String chunk) {
if (chunk.startsWith("data: ")) {
String data = chunk.substring(6);
if ("[DONE]".equals(data)) {
return Mono.empty();
}
try {
ChatCompletionChunk completion = objectMapper.readValue(data, ChatCompletionChunk.class);
return Mono.just(completion);
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}
return Mono.empty();
}
private void handleError(ClientHttpRequest request,
ClientHttpResponse response,
ChatCompletionRequest originalRequest) {
// 解析错误响应
try {
String body = new String(response.getBody().readAllBytes(), StandardCharsets.UTF_8);
OpenAiError error = objectMapper.readValue(body, OpenAiError.class);
throw new OpenAiHttpException(error, response.getStatusCode());
} catch (IOException e) {
throw new RuntimeException("Failed to parse error response", e);
}
}
}
4. 请求/响应模型
4.1 请求模型
// spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletionRequest.java
public class ChatCompletionRequest {
private String model;
private List<ChatCompletionMessage> messages;
private Double temperature;
private Double topP;
private Integer n = 1;
private Boolean stream = false;
private List<String> stop;
private Integer maxTokens;
private Double presencePenalty = 0.0;
private Double frequencyPenalty = 0.0;
private Map<String, Integer> logitBias;
private String user;
@Data
@Builder
public static class ChatCompletionMessage {
private String role; // "system", "user", "assistant"
private String content;
private String name;
private List<ToolCall> toolCalls;
}
}
4.2 响应模型
// spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ChatCompletion.java
/**
 * Response body of {@code POST /chat/completions}.
 *
 * <p>What the annotations provide:
 * <ul>
 *   <li>{@code @Data} provides the getters the chat client reads ({@code getChoices()},
 *       {@code getId()}, {@code getUsage()}, ...) and the setters Jackson uses.</li>
 *   <li>{@code @JsonProperty} binds the API's snake_case field names.</li>
 *   <li>{@code @JsonIgnoreProperties(ignoreUnknown = true)} keeps deserialization working when
 *       the API adds fields this model does not declare.</li>
 * </ul>
 */
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class ChatCompletion {
    private String id;
    private String object = "chat.completion";
    /** Creation timestamp as returned by the API (presumably Unix epoch seconds — confirm). */
    private Long created;
    private String model;
    private List<Choice> choices;
    private Usage usage;

    /** One generated alternative. */
    @Data
    @JsonIgnoreProperties(ignoreUnknown = true)
    public static class Choice {
        private Integer index;
        private ChatCompletionMessage message;
        @JsonProperty("finish_reason")
        private String finishReason;
        /**
         * NOTE(review): the API returns {@code logprobs} as an object (or null), which cannot be
         * bound to a Double. It is ignored so that responses containing it do not fail.
         * Model it with a proper type if log-probabilities are needed.
         */
        @JsonIgnore
        private Double logprobs;
    }

    /** Token accounting for the request. */
    @Data
    @JsonIgnoreProperties(ignoreUnknown = true)
    public static class Usage {
        @JsonProperty("prompt_tokens")
        private Integer promptTokens;
        @JsonProperty("completion_tokens")
        private Integer completionTokens;
        @JsonProperty("total_tokens")
        private Integer totalTokens;
    }
}
5. 核心调用流程
// Simulated core invocation flow
/**
 * Illustrates the end-to-end chat flow: build a prompt, call the client, inspect the response.
 */
public class ChatClientInvocationProcess {

    private static final Logger log = LoggerFactory.getLogger(ChatClientInvocationProcess.class);

    /**
     * Sends {@code userMessage} as a single user message and returns the response.
     *
     * @param chatClient  client to call
     * @param userMessage user message text
     * @return the response, unchanged
     */
    public ChatResponse invokeChat(ChatClient chatClient, String userMessage) {
        // 1. Build the prompt
        Prompt prompt = new Prompt(new UserMessage(userMessage));
        // 2. Call the ChatClient
        ChatResponse response = chatClient.call(prompt);
        // 3. Process the response
        return processResponse(response);
    }

    /**
     * Logs token usage when the metadata carries a {@code Usage} under "usage"; returns the
     * response unchanged.
     */
    private ChatResponse processResponse(ChatResponse response) {
        Map<String, Object> metadata = response.getMetadata();
        Object usageValue = metadata != null ? metadata.get("usage") : null;
        // Type check instead of a blind cast: a provider that stores a different usage type
        // no longer causes a ClassCastException.
        if (usageValue instanceof Usage) {
            Usage usage = (Usage) usageValue;
            log.info("Token usage - Prompt: {}, Completion: {}, Total: {}",
                    usage.getPromptTokens(),
                    usage.getCompletionTokens(),
                    usage.getTotalTokens());
        }
        return response;
    }
}
6. 向量存储实现
6.1 VectorStore 接口
// spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java
/**
 * Stores documents and finds them by vector similarity.
 */
public interface VectorStore {

    /** Adds the given documents to the store. */
    void add(List<Document> documents);

    /** Core similarity search driven by a fully specified {@link SearchRequest}. */
    List<Document> similaritySearch(SearchRequest request);

    /**
     * Convenience similarity search by query text, using the defaults of
     * {@code SearchRequest.query(query)}.
     * This is a plain vector-similarity search, not a hybrid keyword + vector search.
     */
    default List<Document> similaritySearch(String query) {
        return similaritySearch(SearchRequest.query(query));
    }

    /**
     * Similarity search restricted by a metadata filter.
     *
     * @param request base search request
     * @param filter  metadata filter; {@code null} means "no filter", in which case the request
     *                is executed unchanged
     */
    default List<Document> similaritySearch(SearchRequest request,
                                            MetadataFilter filter) {
        if (filter == null) {
            return similaritySearch(request);
        }
        SearchRequest filteredRequest = SearchRequest.from(request)
                .withFilterExpression(filter.getExpression())
                .build();
        return similaritySearch(filteredRequest);
    }

    /**
     * Deletes documents by id.
     *
     * <p>The default performs no deletion and returns {@link Optional#empty()}; presumably that
     * signals "deletion not supported" — stores that support deletion override this.
     */
    default Optional<Boolean> delete(List<String> idList) {
        return Optional.empty();
    }
}
6.2 SimpleVectorStore 实现
// spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java
/**
 * In-memory {@link VectorStore}: documents live in a {@link ConcurrentHashMap}, and their
 * vectors are indexed by a {@code RandomAccessSearcher}.
 *
 * <p>Vectors are typed {@code List<Double>}, which is what {@code EmbeddingClient.embed(String)}
 * and {@code embed(List<String>)} return.
 * NOTE(review): re-adding an existing id replaces the map entry, but {@code searcher.add} is
 * called again. Confirm the searcher de-duplicates by id.
 */
public class SimpleVectorStore implements VectorStore {

    private final EmbeddingClient embeddingClient;
    private final Map<String, EmbeddingDocument> store = new ConcurrentHashMap<>();
    private final RandomAccessSearcher searcher = new RandomAccessSearcher();

    public SimpleVectorStore(EmbeddingClient embeddingClient) {
        this.embeddingClient = embeddingClient;
    }

    /**
     * Embeds all document contents in one request, then stores and indexes each document.
     * A {@code null} or empty list is a no-op and does not trigger an embedding request.
     *
     * @throws IllegalStateException if the embedding client returns a different number of
     *                               vectors than documents
     */
    @Override
    public void add(List<Document> documents) {
        if (documents == null || documents.isEmpty()) {
            return;
        }
        // 1. Embed the documents
        List<String> texts = documents.stream()
                .map(Document::getContent)
                .collect(Collectors.toList());
        List<List<Double>> embeddings = embeddingClient.embed(texts);
        if (embeddings.size() != documents.size()) {
            throw new IllegalStateException("Expected " + documents.size()
                    + " embeddings but received " + embeddings.size());
        }
        // 2. Store each document together with its vector
        for (int i = 0; i < documents.size(); i++) {
            Document doc = documents.get(i);
            EmbeddingDocument embeddingDoc = new EmbeddingDocument(doc.getId(),
                    doc.getContent(),
                    embeddings.get(i),
                    doc.getMetadata());
            store.put(doc.getId(), embeddingDoc);
            searcher.add(embeddingDoc);
        }
    }

    /**
     * Embeds the query, asks the searcher for the top-K hits above the similarity threshold,
     * and maps each hit back to a {@link Document} carrying its score.
     */
    @Override
    public List<Document> similaritySearch(SearchRequest request) {
        // 1. Embed the query
        List<Double> queryEmbedding = embeddingClient.embed(request.getQuery());
        // 2. Find similar documents
        List<SearchResult> results = searcher.search(queryEmbedding,
                request.getTopK(),
                request.getSimilarityThreshold());
        // 3. Convert hits to Documents
        return results.stream()
                .map(result -> {
                    EmbeddingDocument embeddingDoc = store.get(result.getId());
                    return Document.builder()
                            .id(embeddingDoc.getId())
                            .content(embeddingDoc.getContent())
                            .metadata(embeddingDoc.getMetadata())
                            .score(result.getScore())
                            .build();
                })
                .collect(Collectors.toList());
    }
}
7. 后处理器机制
// spring-ai-core/src/main/java/org/springframework/ai/client/RequestResponsePostProcessor.java
/**
 * Hook around a chat call. Both methods default to the identity, so implementations override
 * only the side they need.
 */
public interface RequestResponsePostProcessor {

    /**
     * Response hook; may return a replacement response.
     *
     * @param response the model response
     * @param prompt   the prompt that produced it
     * @return {@code response} unchanged by default
     */
    default ChatResponse postProcess(ChatResponse response, Prompt prompt) {
        return response;
    }

    /**
     * Request hook; may return a replacement prompt.
     *
     * @param prompt the outgoing prompt
     * @return {@code prompt} unchanged by default
     */
    default Prompt preProcess(Prompt prompt) {
        return prompt;
    }
}
// Example implementation: logging post-processor
/**
 * Logs, at DEBUG level, the contents of each prompt and the "usage" metadata entry of each
 * response. It never modifies either.
 */
@Component
public class LoggingPostProcessor implements RequestResponsePostProcessor {

    private static final Logger log = LoggerFactory.getLogger(LoggingPostProcessor.class);

    @Override
    public Prompt preProcess(Prompt prompt) {
        // SLF4J placeholders defer only the formatting, not the evaluation of arguments.
        // Without this guard, getContents() would run on every call even with DEBUG off.
        if (log.isDebugEnabled()) {
            log.debug("Pre-processing prompt: {}", prompt.getContents());
        }
        return prompt;
    }

    @Override
    public ChatResponse postProcess(ChatResponse response, Prompt prompt) {
        if (log.isDebugEnabled()) {
            log.debug("Post-processing response. Tokens used: {}",
                    response.getMetadata().get("usage"));
        }
        return response;
    }
}
8. 重试机制实现
// spring-ai-core/src/main/java/org/springframework/ai/retry/Retryable.java
/**
 * Implemented by clients that expose a replaceable Spring Retry {@link RetryTemplate}.
 */
public interface Retryable {

    /** @return the retry template currently in use */
    RetryTemplate getRetryTemplate();

    /**
     * Replaces the retry template.
     *
     * @param retryTemplate the template to use from now on
     */
    void setRetryTemplate(RetryTemplate retryTemplate);
}
// Default retry configuration
/**
 * Factory for the default {@link RetryTemplate} used around AI HTTP calls.
 */
public final class DefaultRetryConfig {

    private DefaultRetryConfig() {
        // Static factory holder; not meant to be instantiated.
    }

    /**
     * Creates a retry template that:
     * <ul>
     *   <li>makes at most 3 attempts in total (the first call plus up to 2 retries);</li>
     *   <li>retries only {@link ResourceAccessException}, the I/O-level failure (timeouts,
     *       refused or reset connections) thrown by Spring's RestClient/RestTemplate. It also
     *       retries when that exception appears as the cause of another one.</li>
     *   <li>waits with exponential back-off: 1 s, then 2 s, ..., capped at 10 s.</li>
     * </ul>
     *
     * <p>{@code TransientDataAccessException} is not used: it is a Spring data-access (DAO)
     * exception that HTTP clients never throw, so a policy keyed on it never retries an AI call.
     * NOTE(review): HTTP 429/5xx are not retried here, because OpenAiApi's error handler
     * converts them into OpenAiHttpException. Add that type if it should count as transient.
     */
    public static RetryTemplate createDefaultRetryTemplate() {
        RetryTemplate retryTemplate = new RetryTemplate();
        // Retry policy: at most 3 attempts, only for transient I/O failures (causes are traversed)
        SimpleRetryPolicy retryPolicy = new SimpleRetryPolicy(3,
                Collections.singletonMap(ResourceAccessException.class, true),
                true);
        // Back-off policy: 1 s initial wait, multiplier 2, capped at 10 s
        ExponentialBackOffPolicy backOffPolicy = new ExponentialBackOffPolicy();
        backOffPolicy.setInitialInterval(1000);
        backOffPolicy.setMultiplier(2.0);
        backOffPolicy.setMaxInterval(10000);
        retryTemplate.setRetryPolicy(retryPolicy);
        retryTemplate.setBackOffPolicy(backOffPolicy);
        return retryTemplate;
    }
}
🔄 三、核心流程时序图(原图缺失)
📊 四、关键设计模式分析
1. 模板方法模式
// Abstract template class that fixes the standard flow
/**
 * Template Method example: {@link #chat(ChatRequest)} fixes the order of the steps, and
 * subclasses supply the provider-specific ones.
 */
public abstract class AbstractAiClient {

    /**
     * Runs the steps in this fixed order:
     * <ol>
     *   <li>preProcess</li>
     *   <li>toProviderRequest</li>
     *   <li>executeRequest</li>
     *   <li>toChatResponse</li>
     *   <li>postProcess</li>
     * </ol>
     * The method is {@code final} so subclasses cannot reorder or skip steps.
     */
    public final ChatResponse chat(ChatRequest request) {
        // 1. Pre-processing hook
        preProcess(request);
        // 2. Build the provider request (subclass)
        Object providerRequest = toProviderRequest(request);
        // 3. Execute the request (subclass)
        Object rawResponse = executeRequest(providerRequest);
        // 4. Convert the response (subclass)
        ChatResponse response = toChatResponse(rawResponse);
        // 5. Post-processing hook
        return postProcess(response, request);
    }

    /** Converts the neutral request into the provider-specific request object. */
    protected abstract Object toProviderRequest(ChatRequest request);

    /**
     * Sends the provider request and returns the raw provider response.
     * {@link #chat(ChatRequest)} calls this, so it must be declared here; without the
     * declaration the class does not compile.
     */
    protected abstract Object executeRequest(Object providerRequest);

    /** Converts the raw provider response into a {@link ChatResponse}. */
    protected abstract ChatResponse toChatResponse(Object response);

    /** Hook: inspect the request before it is converted. Default: no-op. */
    protected void preProcess(ChatRequest request) {}

    /** Hook: may replace the final response. Default: returns it unchanged. */
    protected ChatResponse postProcess(ChatResponse response, ChatRequest request) {
        return response;
    }
}
2. 策略模式
// Strategy interface
/**
 * One pluggable AI provider (Strategy pattern).
 */
public interface AiProviderStrategy {

    /**
     * @param model model name requested by the caller
     * @return whether this provider can serve {@code model}
     */
    boolean supports(String model);

    /** @return unique provider key, e.g. {@code "openai"} */
    String getProviderName();

    /** Executes a chat request against this provider. */
    ChatResponse chat(ChatRequest request);
}
// Concrete strategy
/**
 * OpenAI implementation of {@link AiProviderStrategy}. It is registered only when
 * {@code spring.ai.openai.enabled=true}.
 */
@Component
@ConditionalOnProperty(name = "spring.ai.openai.enabled", havingValue = "true")
public class OpenAiStrategy implements AiProviderStrategy {

    @Override
    public String getProviderName() {
        return "openai";
    }

    /**
     * Implements the abstract {@code supports} method; without it this class does not compile.
     * TODO(review): the "gpt-" prefix only covers the GPT model family; extend it for other
     * OpenAI model families as needed.
     */
    @Override
    public boolean supports(String model) {
        return model != null && model.startsWith("gpt-");
    }

    /**
     * Placeholder. It throws instead of returning {@code null}, so callers fail immediately
     * with a clear message rather than with a later NullPointerException.
     */
    @Override
    public ChatResponse chat(ChatRequest request) {
        // TODO: OpenAI-specific implementation
        throw new UnsupportedOperationException("OpenAI strategy chat() is not implemented yet");
    }
}
3. 工厂模式
/**
 * Factory that picks an {@link AiProviderStrategy} by provider name and wraps it in a client.
 */
@Component
public class AiClientFactory {

    /** Strategies keyed by {@link AiProviderStrategy#getProviderName()}. */
    private final Map<String, AiProviderStrategy> strategies;

    /**
     * @param strategyList all strategy beans
     * @throws IllegalStateException if two strategies report the same provider name
     */
    public AiClientFactory(List<AiProviderStrategy> strategyList) {
        this.strategies = strategyList.stream()
                .collect(Collectors.toMap(
                        AiProviderStrategy::getProviderName,
                        Function.identity()
                ));
    }

    /**
     * Creates a client for {@code provider}/{@code model}.
     *
     * @throws IllegalArgumentException if the provider is unknown, or if its strategy reports
     *                                  that it does not support {@code model}
     */
    public AiClient createClient(String provider, String model) {
        AiProviderStrategy strategy = strategies.get(provider);
        if (strategy == null) {
            throw new IllegalArgumentException("Unsupported provider: " + provider);
        }
        // Consult the strategy's own capability check instead of failing later at call time.
        if (!strategy.supports(model)) {
            throw new IllegalArgumentException("Provider " + provider + " does not support model: " + model);
        }
        return new AiClientAdapter(strategy, model);
    }
}
🎯 五、性能优化实现
1. 连接池管理
/**
 * Pooled HTTP configuration for AI RestClients (Apache HttpClient 5, which is the version
 * Spring 6's {@code HttpComponentsClientHttpRequestFactory} uses).
 */
@Configuration
public class ConnectionPoolConfig {

    /**
     * Singleton request factory backed by one shared connection pool: at most 100 connections
     * in total and 20 per route, with a 30 s keep-alive.
     * Because it is a bean, Spring calls its {@code destroy()} on shutdown, which closes the
     * HttpClient and its pool.
     */
    @Bean
    public HttpComponentsClientHttpRequestFactory aiClientHttpRequestFactory() {
        PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();
        connectionManager.setMaxTotal(100);
        connectionManager.setDefaultMaxPerRoute(20);
        HttpClient httpClient = HttpClients.custom()
                .setConnectionManager(connectionManager)
                // HttpClient 5's ConnectionKeepAliveStrategy returns a TimeValue, not a long.
                .setKeepAliveStrategy((response, context) -> TimeValue.ofSeconds(30))
                .build();
        return new HttpComponentsClientHttpRequestFactory(httpClient);
    }

    /**
     * RestClient.Builder using the pooled request factory.
     *
     * <p>The bean is prototype-scoped, like Spring Boot's own RestClient.Builder bean. Builders
     * are mutable, and consumers such as OpenAiApi call baseUrl()/defaultHeader()/
     * requestInterceptor() on them. A shared singleton would therefore accumulate every
     * consumer's headers and interceptors.
     * The call to {@link #aiClientHttpRequestFactory()} goes through the @Configuration proxy,
     * so every builder gets the same singleton factory and pool.
     */
    @Bean
    @Scope("prototype")
    public RestClient.Builder aiRestClientBuilder() {
        return RestClient.builder()
                .requestFactory(aiClientHttpRequestFactory());
    }
}
2. 响应缓存
/**
 * Decorator that caches chat responses in the "aiResponses" cache; {@code null} results are
 * not cached.
 *
 * <p>Caveats:
 * <ul>
 *   <li>The cache key is the {@link Prompt} itself, so hits depend on Prompt's equals/hashCode
 *       (presumably value-based — confirm). Keying on {@code #prompt.hashCode()} alone would
 *       let two different prompts with colliding hash codes share one entry and receive each
 *       other's answers.</li>
 *   <li>Only calls through the Spring proxy are cached. The inherited {@code call(String)}
 *       default method calls {@code call(Prompt)} on the target itself, which bypasses the
 *       cache.</li>
 *   <li>NOTE(review): this bean is itself a ChatClient, so other ChatClient injection points
 *       may see two candidates; consider {@code @Primary} or a qualifier.</li>
 * </ul>
 */
@Component
@CacheConfig(cacheNames = "aiResponses")
public class CachingAiClient implements ChatClient {

    private final ChatClient delegate;

    public CachingAiClient(ChatClient delegate) {
        this.delegate = delegate;
    }

    /** Returns the cached response for an equal prompt, otherwise delegates and caches the result. */
    @Override
    @Cacheable(key = "#prompt", unless = "#result == null")
    public ChatResponse call(Prompt prompt) {
        return delegate.call(prompt);
    }

    /** Evicts every entry of the "aiResponses" cache. */
    @CacheEvict(allEntries = true)
    public void clearCache() {
        // Eviction is performed by @CacheEvict; no body needed.
    }
}
🔍 六、调试和监控
1. 日志配置
logging:
  level:
    org.springframework.ai: DEBUG
    org.springframework.ai.openai: INFO
    # org.apache.http is the Apache HttpClient 4 package.
    org.apache.http: WARN
    # Apache HttpClient 5 (used by Spring 6's HttpComponentsClientHttpRequestFactory) logs under org.apache.hc.*
    org.apache.hc: WARN
2. 监控端点
/**
 * Actuator endpoint {@code ai} that exposes the measurements of every
 * {@code ai.client.requests} meter.
 *
 * <p>It uses {@code @Endpoint} because {@code @ReadOperation} is only honoured on
 * {@code @Endpoint}/{@code @WebEndpoint} types. {@code @RestControllerEndpoint} maps methods
 * via {@code @GetMapping}/{@code @RequestMapping} and ignores {@code @ReadOperation}; it is
 * also deprecated in recent Spring Boot versions.
 * NOTE(review): the endpoint must be registered as a bean (e.g. {@code @Component} or
 * {@code @Bean}) to be discovered.
 */
@Endpoint(id = "ai")
public class AiMetricsEndpoint {

    private final MeterRegistry meterRegistry;

    public AiMetricsEndpoint(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
    }

    /**
     * Returns the measurements of each {@code ai.client.requests} meter.
     *
     * <p>All meters found here share the same name and differ only by their tags, so the key
     * combines the name with the tags. Keyed by name alone, each meter would overwrite the
     * previous one and only the last would be reported.
     */
    @ReadOperation
    public Map<String, Object> metrics() {
        Map<String, Object> metrics = new HashMap<>();
        meterRegistry.find("ai.client.requests").meters()
                .forEach(meter -> metrics.put(
                        meter.getId().getName() + meter.getId().getTags(),
                        meter.measure()));
        return metrics;
    }
}
📈 七、源码学习建议
1. 学习路径
- 从接口开始:先理解 ChatClient、EmbeddingClient 等核心接口
- 查看自动配置:学习 OpenAiAutoConfiguration 等配置类
- 跟踪实现:选择一个具体实现(如 OpenAiChatClient)深入理解
- 研究设计模式:注意模板方法、策略、工厂等模式的应用
- 调试运行:实际运行示例代码,打断点跟踪执行流程
2. 关键调试技巧
// Enable verbose (DEBUG) logging of prompts and responses
@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    /** Registers a post-processor that logs every prompt and response text at DEBUG level. */
    @Bean
    public RequestResponsePostProcessor debugProcessor() {
        return new DebugLoggingProcessor();
    }

    /** Logs outgoing prompt contents and the incoming response text; returns both unchanged. */
    private static final class DebugLoggingProcessor implements RequestResponsePostProcessor {

        @Override
        public Prompt preProcess(Prompt prompt) {
            log.debug(">>> Request: {}", prompt.getContents());
            return prompt;
        }

        @Override
        public ChatResponse postProcess(ChatResponse response, Prompt prompt) {
            String responseText = response.getResult().getOutput().getContent();
            log.debug("<<< Response: {}", responseText);
            return response;
        }
    }
}
这个源码分析涵盖了 Spring AI 的核心架构、关键实现和设计思想。通过理解这些代码,您可以更深入地掌握 Spring AI 的工作原理,并在实际开发中更好地使用和扩展它。需要注意:文中代码是为便于讲解而简化的示意代码,与 Spring AI 的实际源码并不完全一致(不同版本之间类名、包结构和 API 变化较大),阅读时请以官方仓库中对应版本的源码为准。
199

被折叠的 条评论
为什么被折叠?



