SpringAI兼容百炼平台教程(黑马课程)

兼容百炼平台

截至Spring AI 1.0.0-M6版本,Spring AI的OpenAiChatModel与阿里云百炼的部分接口存在兼容性问题,包括但不限于以下两个:

  • FunctionCalling的stream模式:阿里云百炼以分片形式返回tool-arguments,Spring AI无法自动将其合并为完整参数,需要自行拼接;而对接OpenAI时Spring AI会自动合并,得到的参数是完整的,无需额外拼接。
  • 音频识别中的数据格式:阿里云百炼的qwen-omni模型要求的参数格式为data:;base64,${media-data},而OpenAI则直接传入${media-data}。

由于Spring AI的OpenAI模块遵循的是OpenAI规范,即便版本升级也不会去兼容阿里云(除非Spring AI单独为阿里云开发starter)。因此目前的解决方案有两个:

  • 等待阿里云官方推出的Spring AI Alibaba(spring-ai-alibaba)升级到最新版本;
  • 自己重写OpenAiChatModel的实现逻辑。

接下来,我们就通过重写OpenAiChatModel的方式来解决上述两个问题。

AlibabaOpenAiChatModel

首先,我们自己编写一个遵循阿里云百炼平台接口规范的ChatModel。其中大部分代码来自Spring AI的OpenAiChatModel,只需重写接口协议不匹配的部分即可(原教程中重写部分以黄色高亮显示,纯文本中该高亮可能丢失)。

新建一个AlibabaOpenAiChatModel类:

package com.itheima.ai.model;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.*;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class AlibabaOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {

    // Logger for this chat model (used for warnings about missing completions, deprecated
    // constructors and stream-processing errors).
    private static final Logger logger = LoggerFactory.getLogger(AlibabaOpenAiChatModel.class);

    // Observation convention used whenever no custom convention has been assigned to
    // observationConvention.
    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    // Shared default tool-calling manager. NOTE(review): not referenced in the visible part
    // of the class; presumably used by a builder further down the file — confirm.
    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();

    /**
     * The default options used for the chat completion requests. Runtime prompt options
     * are merged on top of these in buildRequestPrompt.
     */
    private final OpenAiChatOptions defaultOptions;

    /**
     * The retry template used to retry the OpenAI API calls. In the visible code it wraps
     * only the blocking call path (internalCall); the streaming path does not use it.
     */
    private final RetryTemplate retryTemplate;

    /**
     * Low-level access to the OpenAI-compatible HTTP API (pointed at Bailian by the caller).
     */
    private final OpenAiApi openAiApi;

    /**
     * Observation registry used for instrumentation.
     */
    private final ObservationRegistry observationRegistry;

    // Resolves tool definitions and executes tool calls requested by the model.
    private final ToolCallingManager toolCallingManager;

    /**
     * Conventions to use for generating observations. Non-final; presumably replaceable
     * through a setter outside the visible part of the file.
     */
    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a chat model for the given API client, using {@link OpenAiApi#DEFAULT_CHAT_MODEL}
     * with a temperature of 0.7 as the default options.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @throws IllegalArgumentException if openAiApi is null
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi) {
        this(openAiApi,
                OpenAiChatOptions.builder()
                    .temperature(0.7)
                    .model(OpenAiApi.DEFAULT_CHAT_MODEL)
                    .build());
    }

    /**
     * Creates a chat model with the given default options, no function-callback resolver
     * and Spring AI's default retry template.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @param options default options applied to every request
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
        this(openAiApi, options, (FunctionCallbackResolver) null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a chat model with the given options, resolver and retry template, and no
     * pre-registered tool function callbacks.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @param options default options applied to every request
     * @param functionCallbackResolver optional resolver for function callbacks by name
     * @param retryTemplate retry policy for blocking API calls
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
                           @Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
        this(openAiApi, options, functionCallbackResolver, List.<FunctionCallback>of(), retryTemplate);
    }

    /**
     * Creates a chat model with the given tool callbacks and a no-op observation registry.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @param options default options applied to every request
     * @param functionCallbackResolver optional resolver for function callbacks by name
     * @param toolFunctionCallbacks optional callbacks available to the model as tools
     * @param retryTemplate retry policy for blocking API calls
     * @deprecated Use AlibabaOpenAiChatModel.Builder.
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
                           @Nullable FunctionCallbackResolver functionCallbackResolver,
                           @Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
        this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a chat model that adapts the legacy function-callback configuration into a
     * {@link LegacyToolCallingManager}. Logs a deprecation warning on every construction.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @param options default options applied to every request
     * @param functionCallbackResolver optional resolver for function callbacks by name
     * @param toolFunctionCallbacks optional callbacks available to the model as tools
     * @param retryTemplate retry policy for blocking API calls
     * @param observationRegistry registry used for instrumentation
     * @deprecated Use AlibabaOpenAiChatModel.Builder or AlibabaOpenAiChatModel(OpenAiApi,
     * OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry).
     */
    @Deprecated
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
                           @Nullable FunctionCallbackResolver functionCallbackResolver,
                           @Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
                           ObservationRegistry observationRegistry) {
        this(openAiApi, options, legacyToolCallingManagerOf(functionCallbackResolver, toolFunctionCallbacks),
                retryTemplate, observationRegistry);
        logger.warn("This constructor is deprecated and will be removed in the next milestone. "
                + "Please use the AlibabaOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
    }

    /**
     * Wraps a legacy resolver/callback pair in a {@link LegacyToolCallingManager}.
     */
    private static ToolCallingManager legacyToolCallingManagerOf(@Nullable FunctionCallbackResolver resolver,
                                                                 @Nullable List<FunctionCallback> callbacks) {
        return LegacyToolCallingManager.builder()
                .functionCallbackResolver(resolver)
                .functionCallbacks(callbacks)
                .build();
    }

    /**
     * Primary constructor; every other constructor delegates here.
     * @param openAiApi client used to talk to the OpenAI-compatible chat endpoint
     * @param defaultOptions default options merged under each prompt's runtime options
     * @param toolCallingManager resolves tool definitions and executes tool calls
     * @param retryTemplate retry policy for blocking API calls
     * @param observationRegistry registry used for instrumentation
     * @throws IllegalArgumentException if any argument is null
     */
    public AlibabaOpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
                           RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        // The inherited AbstractToolCallSupport would mutate any options handed to it, and
        // tool handling is delegated to the ToolCallingManager anyway, so give it a fresh
        // empty options instance instead of defaultOptions.
        super(null, OpenAiChatOptions.builder().build(), List.of());

        Assert.notNull(openAiApi, "openAiApi cannot be null");
        this.openAiApi = openAiApi;

        Assert.notNull(defaultOptions, "defaultOptions cannot be null");
        this.defaultOptions = defaultOptions;

        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        this.toolCallingManager = toolCallingManager;

        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        this.retryTemplate = retryTemplate;

        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        this.observationRegistry = observationRegistry;
    }

    /**
     * Sends the prompt as a blocking chat-completion request.
     * <p>
     * Runtime options are merged with this model's defaults before the request is built.
     * Any tool-calling round trips happen inside internalCall.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        final Prompt mergedPrompt = buildRequestPrompt(prompt);
        final ChatResponse noPreviousResponse = null;
        return this.internalCall(mergedPrompt, noPreviousResponse);
    }

    /**
     * Performs one blocking chat-completion round trip for an already merged prompt.
     * <p>
     * The HTTP call goes through the retry template and is wrapped in a chat-model
     * observation. The model may ask for tool calls while internal tool execution is
     * enabled in the prompt options. In that case the tools run via the
     * {@link ToolCallingManager}. The result is then either returned directly (when the
     * execution result is {@code returnDirect}) or sent back to the model by recursing
     * with the extended conversation history.
     * @param prompt request prompt whose options are expected to be
     * {@link OpenAiChatOptions} (createRequest casts them), as produced by buildRequestPrompt
     * @param previousChatResponse response of the previous round of a tool-calling loop,
     * used to accumulate token usage; {@code null} on the first round
     * @return the chat response; an empty response if the API returned no body or no choices
     */
    public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

        OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);

        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(OpenAiApiConstants.PROVIDER_NAME)
                .requestOptions(prompt.getOptions())
                .build();

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
                .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                        this.observationRegistry)
                .observe(() -> {

                    ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = this.retryTemplate
                            .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt)));

                    // getBody() is nullable: read it once and use only the checked local below.
                    OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();

                    if (chatCompletion == null) {
                        logger.warn("No chat completion returned for prompt: {}", prompt);
                        return new ChatResponse(List.of());
                    }

                    List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
                    if (choices == null) {
                        logger.warn("No choices returned for prompt: {}", prompt);
                        return new ChatResponse(List.of());
                    }

                    List<Generation> generations = choices.stream().map(choice -> {
                        // @formatter:off
                        Map<String, Object> metadata = Map.of(
                                "id", chatCompletion.id() != null ? chatCompletion.id() : "",
                                "role", choice.message().role() != null ? choice.message().role().name() : "",
                                "index", choice.index(),
                                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                                "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
                        // @formatter:on
                        return buildGeneration(choice, metadata, request);
                    }).toList();

                    RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

                    // Usage of this round, accumulated with the previous rounds of a
                    // tool-calling loop (if any).
                    OpenAiApi.Usage usage = chatCompletion.usage();
                    Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                    Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
                    ChatResponse chatResponse = new ChatResponse(generations,
                            from(chatCompletion, rateLimit, accumulatedUsage));

                    observationContext.setResponse(chatResponse);

                    return chatResponse;

                });

        if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null
                && response.hasToolCalls()) {
            var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
            if (toolExecutionResult.returnDirect()) {
                // Return tool execution result directly to the client.
                return ChatResponse.builder()
                        .from(response)
                        .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                        .build();
            }
            else {
                // Send the tool execution result back to the model.
                return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                        response);
            }
        }

        return response;
    }

    /**
     * Sends the prompt as a streaming chat-completion request.
     * <p>
     * Runtime options are merged with this model's defaults before the request is built.
     * Any tool-calling round trips happen inside internalStream.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        final Prompt mergedPrompt = buildRequestPrompt(prompt);
        final ChatResponse noPreviousResponse = null;
        return internalStream(mergedPrompt, noPreviousResponse);
    }

    /**
     * Streams one chat-completion round for an already merged prompt.
     * <p>
     * The request is built lazily inside {@code Flux.deferContextual}, so nothing is sent
     * until subscription. Each streamed chunk is converted to a ChatCompletion and then to
     * a ChatResponse. When the request asks for stream usage, the usage reported by the
     * next response is copied onto the current one via an overlapping {@code buffer(2, 1)}.
     * A response may carry tool calls while internal tool execution is enabled. Such a
     * response is either answered directly with the tool result or followed by a
     * recursive streaming call with the extended conversation. The final stream is
     * passed through a MessageAggregator, which reports the aggregated response to the
     * observation context.
     * <p>
     * NOTE(review): unlike internalCall, the streaming request is not wrapped in the
     * retry template.
     * @param prompt request prompt whose options are expected to be OpenAiChatOptions
     * @param previousChatResponse response of the previous tool-calling round, used to
     * accumulate token usage; {@code null} on the first round
     * @return the stream of chat responses
     * @throws IllegalArgumentException (signalled as an error on subscription) when the
     * request asks for audio output or carries audio parameters
     */
    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);

            // Audio is rejected for streaming. NOTE(review): the warn messages say
            // "Removing ...", but the code throws instead of removing anything.
            if (request.outputModalities() != null) {
                if (request.outputModalities().stream().anyMatch(m -> m.equals("audio"))) {
                    logger.warn("Audio output is not supported for streaming requests. Removing audio output.");
                    throw new IllegalArgumentException("Audio output is not supported for streaming requests.");
                }
            }
            if (request.audioParameters() != null) {
                logger.warn("Audio parameters are not supported for streaming requests. Removing audio parameters.");
                throw new IllegalArgumentException("Audio parameters are not supported for streaming requests.");
            }

            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
                    getAdditionalHttpHeaders(prompt));

            // For chunked responses, only the first chunk contains the choice role.
            // The rest of the chunks with same ID share the same role.
            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();

            final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
                    .prompt(prompt)
                    .provider(OpenAiApiConstants.PROVIDER_NAME)
                    .requestOptions(prompt.getOptions())
                    .build();

            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);

            // Parent the observation on the subscriber's current observation (if any).
            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
            // the function call handling logic.
            // Tool-call argument fragments in a chunk are concatenated by buildGeneration.
            // NOTE(review): presumably OpenAiApi.chatCompletionStream has already grouped
            // the streamed tool-call chunks; confirm against the Spring AI version in use.
            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
                    .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
                        try {
                            @SuppressWarnings("null")
                            String id = chatCompletion2.id();

                            List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
                                if (choice.message().role() != null) {
                                    roleMap.putIfAbsent(id, choice.message().role().name());
                                }
                                Map<String, Object> metadata = Map.of(
                                        "id", chatCompletion2.id(),
                                        "role", roleMap.getOrDefault(id, ""),
                                        "index", choice.index(),
                                        "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
                                        "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");

                                return buildGeneration(choice, metadata, request);
                            }).toList();
                            // @formatter:on
                            OpenAiApi.Usage usage = chatCompletion2.usage();
                            Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
                            Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage,
                                    previousChatResponse);
                            return new ChatResponse(generations, from(chatCompletion2, null, accumulatedUsage));
                        }
                        catch (Exception e) {
                            // Any failure while converting a chunk (e.g. a null id or index
                            // rejected by ConcurrentHashMap/Map.of) is logged and replaced
                            // by an empty response instead of failing the stream.
                            logger.error("Error processing chat completion", e);
                            return new ChatResponse(List.of());
                        }
                        // When in stream mode and enabled to include the usage, the OpenAI
                        // Chat completion response would have the usage set only in its
                        // final response. Hence, the following overlapping buffer is
                        // created to store both the current and the subsequent response
                        // to accumulate the usage from the subsequent response.
                    }))
                    .buffer(2, 1)
                    .map(bufferList -> {
                        ChatResponse firstResponse = bufferList.get(0);
                        if (request.streamOptions() != null && request.streamOptions().includeUsage()) {
                            if (bufferList.size() == 2) {
                                ChatResponse secondResponse = bufferList.get(1);
                                if (secondResponse != null && secondResponse.getMetadata() != null) {
                                    // This is the usage from the final Chat response for a
                                    // given Chat request.
                                    Usage usage = secondResponse.getMetadata().getUsage();
                                    if (!UsageUtils.isEmpty(usage)) {
                                        // Store the usage from the final response to the
                                        // penultimate response for accumulation.
                                        return new ChatResponse(firstResponse.getResults(),
                                                from(firstResponse.getMetadata(), usage));
                                    }
                                }
                            }
                        }
                        return firstResponse;
                    });

            // @formatter:off
            // Execute requested tools (when enabled) and either short-circuit with the tool
            // result or continue the conversation with a recursive streaming call.
            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {

                        if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) {
                            var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
                            if (toolExecutionResult.returnDirect()) {
                                // Return tool execution result directly to the client.
                                return Flux.just(ChatResponse.builder().from(response)
                                        .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                                        .build());
                            } else {
                                // Send the tool execution result back to the model.
                                return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
                                        response);
                            }
                        }
                        else {
                            return Flux.just(response);
                        }
                    })
                    .doOnError(observation::error)
                    .doFinally(s -> observation.stop())
                    .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
            // @formatter:on

            // Aggregate the streamed responses so the observation context receives the
            // complete response once the stream finishes.
            return new MessageAggregator().aggregate(flux, observationContext::setResponse);

        });
    }

    /**
     * Builds the extra HTTP headers for a request.
     * <p>
     * Starts from the model's default headers. When the prompt's options are
     * {@link OpenAiChatOptions}, their headers override the defaults key by key. Every
     * header is exposed as a single-valued entry.
     */
    private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
        Map<String, String> effectiveHeaders = new HashMap<>(this.defaultOptions.getHttpHeaders());
        // instanceof is false for null, so no separate null check is needed.
        if (prompt.getOptions() instanceof OpenAiChatOptions runtimeChatOptions) {
            effectiveHeaders.putAll(runtimeChatOptions.getHttpHeaders());
        }

        Map<String, List<String>> singleValuedHeaders = new HashMap<>();
        for (Map.Entry<String, String> header : effectiveHeaders.entrySet()) {
            singleValuedHeaders.put(header.getKey(), List.of(header.getValue()));
        }
        return CollectionUtils.toMultiValueMap(singleValuedHeaders);
    }

    /**
     * Converts one completion choice into a Spring AI {@link Generation}.
     * <p>
     * Tool calls are passed through {@link #mergeToolCallFragments}. That helper joins the
     * argument fragments Bailian returns for a single call, but keeps genuinely distinct
     * (parallel) tool calls separate. If the message carries audio output, the audio is
     * decoded into a {@link Media} attachment. Its transcript is used as the text when the
     * message has no text content.
     * @param choice the completion choice to convert
     * @param metadata assistant-message metadata prepared by the caller
     * @param request the originating request (its audio parameters give the audio MIME type)
     * @return the generation wrapping the assistant message and finish-reason metadata
     */
    private Generation buildGeneration(OpenAiApi.ChatCompletion.Choice choice, Map<String, Object> metadata, OpenAiApi.ChatCompletionRequest request) {
        List<AssistantMessage.ToolCall> toolCalls = mergeToolCallFragments(choice.message().toolCalls());

        String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
        var generationMetadataBuilder = ChatGenerationMetadata.builder().finishReason(finishReason);

        List<Media> media = new ArrayList<>();
        String textContent = choice.message().content();
        var audioOutput = choice.message().audioOutput();
        if (audioOutput != null) {
            // Locale.ROOT keeps the MIME subtype stable regardless of the default locale.
            String mimeType = String.format("audio/%s",
                    request.audioParameters().format().name().toLowerCase(Locale.ROOT));
            byte[] audioData = Base64.getDecoder().decode(audioOutput.data());
            Resource resource = new ByteArrayResource(audioData);
            media.add(Media.builder()
                    .mimeType(MimeTypeUtils.parseMimeType(mimeType))
                    .data(resource)
                    .id(audioOutput.id())
                    .build());
            if (!StringUtils.hasText(textContent)) {
                textContent = audioOutput.transcript();
            }
            generationMetadataBuilder.metadata("audioId", audioOutput.id());
            generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
        }

        var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
        return new Generation(assistantMessage, generationMetadataBuilder.build());
    }

    /**
     * Folds raw tool-call entries into complete {@link AssistantMessage.ToolCall}s.
     * <p>
     * An entry is a continuation fragment when its id is blank or equal to the id of the
     * call currently being assembled. Its arguments are then appended to that call, and
     * the name is taken from the first entry that has one. An entry with a new non-blank
     * id starts a new tool call, so parallel tool calls are not merged into one. Null
     * argument fragments are treated as empty when concatenating.
     * NOTE(review): assumes Bailian sends a blank (or repeated) id on continuation chunks
     * of a streamed tool call — confirm against the DashScope compatible-mode output.
     * @param rawToolCalls tool calls as returned by the API; may be {@code null}
     * @return an unmodifiable list of merged tool calls, empty when there are none
     */
    private static List<AssistantMessage.ToolCall> mergeToolCallFragments(
            @Nullable List<OpenAiApi.ChatCompletionMessage.ToolCall> rawToolCalls) {
        if (rawToolCalls == null || rawToolCalls.isEmpty()) {
            return List.of();
        }
        List<AssistantMessage.ToolCall> merged = new ArrayList<>();
        for (OpenAiApi.ChatCompletionMessage.ToolCall raw : rawToolCalls) {
            String id = raw.id();
            String name = raw.function() != null ? raw.function().name() : null;
            String arguments = raw.function() != null ? raw.function().arguments() : null;

            AssistantMessage.ToolCall current = merged.isEmpty() ? null : merged.get(merged.size() - 1);
            boolean continuation = current != null && (!StringUtils.hasText(id) || id.equals(current.id()));
            if (continuation) {
                String mergedName = StringUtils.hasText(current.name()) ? current.name() : name;
                String mergedArguments = Objects.toString(current.arguments(), "") + Objects.toString(arguments, "");
                merged.set(merged.size() - 1,
                        new AssistantMessage.ToolCall(current.id(), "function", mergedName, mergedArguments));
            }
            else {
                merged.add(new AssistantMessage.ToolCall(id, "function", name, arguments));
            }
        }
        return List.copyOf(merged);
    }

    /**
     * Builds response metadata from a chat completion.
     * <p>
     * Missing id, model and system fingerprint become {@code ""}, and a missing creation
     * time becomes {@code 0L}. The rate limit is attached only when non-null.
     * @throws IllegalArgumentException if result is null
     */
    private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit, Usage usage) {
        Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
        var metadataBuilder = ChatResponseMetadata.builder();
        metadataBuilder.id(Objects.requireNonNullElse(result.id(), ""));
        metadataBuilder.usage(usage);
        metadataBuilder.model(Objects.requireNonNullElse(result.model(), ""));
        metadataBuilder.keyValue("created", Objects.requireNonNullElse(result.created(), 0L));
        metadataBuilder.keyValue("system-fingerprint", Objects.requireNonNullElse(result.systemFingerprint(), ""));
        if (rateLimit != null) {
            metadataBuilder.rateLimit(rateLimit);
        }
        return metadataBuilder.build();
    }

    /**
     * Copies id, model and rate limit from existing metadata, replacing its usage.
     * <p>
     * A missing id or model becomes {@code ""}. The rate limit is copied only when non-null.
     * @throws IllegalArgumentException if chatResponseMetadata is null
     */
    private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usage usage) {
        Assert.notNull(chatResponseMetadata, "OpenAI ChatResponseMetadata must not be null");
        String sourceId = chatResponseMetadata.getId();
        String sourceModel = chatResponseMetadata.getModel();
        var metadataBuilder = ChatResponseMetadata.builder()
                .id(sourceId == null ? "" : sourceId)
                .usage(usage)
                .model(sourceModel == null ? "" : sourceModel);
        RateLimit sourceRateLimit = chatResponseMetadata.getRateLimit();
        if (sourceRateLimit != null) {
            metadataBuilder.rateLimit(sourceRateLimit);
        }
        return metadataBuilder.build();
    }

    /**
     * Converts a streamed ChatCompletionChunk into a ChatCompletion so the non-streaming
     * conversion logic can be reused.
     * <p>
     * Each chunk choice's delta becomes the choice message. Id, created, model, service
     * tier, system fingerprint and usage are copied from the chunk, and the object type is
     * set to {@code "chat.completion"}. The usage is taken as-is from the chunk; it is not
     * forced to null.
     * @param chunk the ChatCompletionChunk to convert
     * @return the equivalent ChatCompletion
     */
    private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
        List<OpenAiApi.ChatCompletion.Choice> convertedChoices = new ArrayList<>();
        for (var chunkChoice : chunk.choices()) {
            convertedChoices.add(new OpenAiApi.ChatCompletion.Choice(chunkChoice.finishReason(),
                    chunkChoice.index(), chunkChoice.delta(), chunkChoice.logprobs()));
        }

        return new OpenAiApi.ChatCompletion(chunk.id(), Collections.unmodifiableList(convertedChoices),
                chunk.created(), chunk.model(), chunk.serviceTier(), chunk.systemFingerprint(),
                "chat.completion", chunk.usage());
    }

    /**
     * Wraps OpenAI token usage in a Spring AI {@link DefaultUsage}, keeping the raw usage
     * object as the native usage.
     */
    private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
        var promptTokens = usage.promptTokens();
        var completionTokens = usage.completionTokens();
        var totalTokens = usage.totalTokens();
        return new DefaultUsage(promptTokens, completionTokens, totalTokens, usage);
    }

    /**
     * Builds the prompt actually sent to the API.
     * <p>
     * The prompt's runtime options are converted to {@link OpenAiChatOptions} and merged
     * over this model's defaults. Fields that {@code ModelOptionsUtils.merge} cannot see
     * because they are {@code @JsonIgnore}d are then merged explicitly: HTTP headers,
     * internal tool execution, tool names, tool callbacks and tool context. Finally the
     * tool callbacks are validated.
     * @param prompt the caller's prompt
     * @return a prompt with the same instructions and the merged options
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        OpenAiChatOptions runtimeOptions = prompt.getOptions() == null ? null
                : toOpenAiRuntimeOptions(prompt.getOptions());

        OpenAiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                OpenAiChatOptions.class);

        if (runtimeOptions == null) {
            // No runtime options: the ignored fields come straight from the defaults.
            requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
            requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());
            requestOptions.setToolNames(this.defaultOptions.getToolNames());
            requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            requestOptions.setToolContext(this.defaultOptions.getToolContext());
        }
        else {
            // Runtime values take precedence over defaults for each ignored field.
            requestOptions.setHttpHeaders(
                    mergeHttpHeaders(runtimeOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
            requestOptions.setInternalToolExecutionEnabled(ModelOptionsUtils.mergeOption(
                    runtimeOptions.isInternalToolExecutionEnabled(), this.defaultOptions.isInternalToolExecutionEnabled()));
            requestOptions.setToolNames(
                    ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames()));
            requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
                    this.defaultOptions.getToolCallbacks()));
            requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
                    this.defaultOptions.getToolContext()));
        }

        ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());

        return new Prompt(prompt.getInstructions(), requestOptions);
    }

    /**
     * Copies non-null runtime options into a new {@link OpenAiChatOptions}, using the most
     * specific known options interface (tool-calling, then function-calling, then plain chat).
     */
    private static OpenAiChatOptions toOpenAiRuntimeOptions(ChatOptions options) {
        if (options instanceof ToolCallingChatOptions toolCallingChatOptions) {
            return ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
                    OpenAiChatOptions.class);
        }
        if (options instanceof FunctionCallingOptions functionCallingOptions) {
            return ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
                    OpenAiChatOptions.class);
        }
        return ModelOptionsUtils.copyToTarget(options, ChatOptions.class, OpenAiChatOptions.class);
    }

    /**
     * Returns a new map of the default headers overlaid with the runtime headers (runtime
     * wins on key conflicts). Neither input map is modified.
     */
    private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
                                                 Map<String, String> defaultHttpHeaders) {
        Map<String, String> combined = new HashMap<>(defaultHttpHeaders);
        runtimeHttpHeaders.forEach(combined::put);
        return combined;
    }

    /**
     * Converts a {@link Prompt} into the OpenAI-compatible wire request sent to the Bailian endpoint.
     * Package-private so that tests can call it directly.
     * <p>
     * The request is assembled in four steps:
     * <ol>
     * <li>every prompt message is turned into one or more {@code ChatCompletionMessage}s; a
     * {@link ToolResponseMessage} fans out into one message per tool response;</li>
     * <li>the prompt's {@link OpenAiChatOptions} are merged into the request;</li>
     * <li>tool definitions resolved by {@code toolCallingManager} are attached as function tools;</li>
     * <li>{@code streamOptions} is removed, with a warning, when the request is not streaming.</li>
     * </ol>
     * @param prompt the prompt to convert; its options are cast to {@link OpenAiChatOptions}
     * @param stream whether the request is for a streaming completion
     * @return the assembled chat completion request
     * @throws IllegalArgumentException if a message has an unsupported message type
     */
    OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        List<OpenAiApi.ChatCompletionMessage> wireMessages = prompt.getInstructions()
                .stream()
                .map(msg -> {
                    MessageType type = msg.getMessageType();

                    if (type == MessageType.USER || type == MessageType.SYSTEM) {
                        Object payload = msg.getText();
                        // A user message with media becomes a multi-part content list:
                        // the text part first, then one part per media item, in order.
                        if (msg instanceof UserMessage userMsg && !CollectionUtils.isEmpty(userMsg.getMedia())) {
                            List<OpenAiApi.ChatCompletionMessage.MediaContent> parts = new ArrayList<>();
                            parts.add(new OpenAiApi.ChatCompletionMessage.MediaContent(msg.getText()));
                            for (Media media : userMsg.getMedia()) {
                                parts.add(this.mapToMediaContent(media));
                            }
                            payload = parts;
                        }
                        return List.of(new OpenAiApi.ChatCompletionMessage(payload,
                                OpenAiApi.ChatCompletionMessage.Role.valueOf(type.name())));
                    }

                    if (type == MessageType.ASSISTANT) {
                        AssistantMessage assistantMsg = (AssistantMessage) msg;

                        List<OpenAiApi.ChatCompletionMessage.ToolCall> toolCalls = CollectionUtils
                                .isEmpty(assistantMsg.getToolCalls()) ? null
                                        : assistantMsg.getToolCalls()
                                                .stream()
                                                .map(call -> new OpenAiApi.ChatCompletionMessage.ToolCall(call.id(),
                                                        call.type(),
                                                        new OpenAiApi.ChatCompletionMessage.ChatCompletionFunction(
                                                                call.name(), call.arguments())))
                                                .toList();

                        // At most one media item is allowed; only its id is sent back as the audio reference.
                        OpenAiApi.ChatCompletionMessage.AudioOutput audioRef = null;
                        if (!CollectionUtils.isEmpty(assistantMsg.getMedia())) {
                            Assert.isTrue(assistantMsg.getMedia().size() == 1,
                                    "Only one media content is supported for assistant messages");
                            String mediaId = assistantMsg.getMedia().get(0).getId();
                            audioRef = new OpenAiApi.ChatCompletionMessage.AudioOutput(mediaId, null, null, null);
                        }

                        return List.of(new OpenAiApi.ChatCompletionMessage(assistantMsg.getText(),
                                OpenAiApi.ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioRef));
                    }

                    if (type == MessageType.TOOL) {
                        ToolResponseMessage toolMsg = (ToolResponseMessage) msg;
                        // Validate every response before converting any of them.
                        for (var checked : toolMsg.getResponses()) {
                            Assert.isTrue(checked.id() != null, "ToolResponseMessage must have an id");
                        }
                        return toolMsg.getResponses()
                                .stream()
                                .map(resp -> new OpenAiApi.ChatCompletionMessage(resp.responseData(),
                                        OpenAiApi.ChatCompletionMessage.Role.TOOL, resp.name(), resp.id(), null, null,
                                        null))
                                .toList();
                    }

                    throw new IllegalArgumentException("Unsupported message type: " + type);
                })
                .flatMap(List::stream)
                .toList();

        OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(wireMessages, stream);

        // Merge the options carried by the prompt into the request.
        OpenAiChatOptions promptOptions = (OpenAiChatOptions) prompt.getOptions();
        request = ModelOptionsUtils.merge(promptOptions, request, OpenAiApi.ChatCompletionRequest.class);

        // Expose the resolved tool definitions through the request's `tools` field.
        List<ToolDefinition> resolvedTools = this.toolCallingManager.resolveToolDefinitions(promptOptions);
        if (!CollectionUtils.isEmpty(resolvedTools)) {
            OpenAiChatOptions toolOptions = OpenAiChatOptions.builder()
                    .tools(this.getFunctionTools(resolvedTools))
                    .build();
            request = ModelOptionsUtils.merge(toolOptions, request, OpenAiApi.ChatCompletionRequest.class);
        }

        // `streamOptions` is only kept for streaming requests.
        boolean hasStreamOptions = request.streamOptions() != null;
        if (hasStreamOptions && !stream) {
            logger.warn("Removing streamOptions from the request as it is not a streaming request!");
            request = request.streamOptions(null);
        }

        return request;
    }

    /**
     * Maps a Spring AI {@link Media} item to an OpenAI-style content part.
     * <p>
     * {@code audio/mp3}, {@code audio/mpeg} and {@code audio/wav} become {@code input_audio} parts
     * (encoded by {@code fromAudioData}). Every other MIME type is sent as an image URL part built by
     * {@code fromMediaData}.
     * @param media the media item to convert
     * @return the corresponding content part
     */
    private OpenAiApi.ChatCompletionMessage.MediaContent mapToMediaContent(Media media) {
        MimeType contentType = media.getMimeType();

        OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format audioFormat = null;
        if (MimeTypeUtils.parseMimeType("audio/mp3").equals(contentType)
                || MimeTypeUtils.parseMimeType("audio/mpeg").equals(contentType)) {
            audioFormat = OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format.MP3;
        }
        else if (MimeTypeUtils.parseMimeType("audio/wav").equals(contentType)) {
            audioFormat = OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio.Format.WAV;
        }

        if (audioFormat == null) {
            // Not a recognised audio type: send it as an image URL / data URL.
            return new OpenAiApi.ChatCompletionMessage.MediaContent(
                    new OpenAiApi.ChatCompletionMessage.MediaContent.ImageUrl(
                            this.fromMediaData(contentType, media.getData())));
        }

        return new OpenAiApi.ChatCompletionMessage.MediaContent(
                new OpenAiApi.ChatCompletionMessage.MediaContent.InputAudio(fromAudioData(media.getData()),
                        audioFormat));
    }

    /**
     * Encodes raw audio bytes in the form expected by Bailian's qwen-omni models:
     * {@code data:;base64,<payload>}. Plain OpenAI expects only the bare base64 payload, which is
     * why this method differs from the upstream Spring AI implementation.
     * @param audioData the audio content; only {@code byte[]} is supported
     * @return the data-URL string for the {@code input_audio} part
     * @throws IllegalArgumentException if {@code audioData} is not a {@code byte[]}
     */
    private String fromAudioData(Object audioData) {
        if (!(audioData instanceof byte[] bytes)) {
            throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
        }
        return "data:;base64," + Base64.getEncoder().encodeToString(bytes);
    }

    /**
     * Produces the URL for an image content part.
     * <p>
     * A {@code String} is passed through untouched; it is assumed to already be a URL or a
     * caller-built base64 data URL. A {@code byte[]} is assumed to be an image and is wrapped as
     * {@code data:<mimeType>;base64,<payload>}.
     * @param mimeType the MIME type written into the data URL for byte content
     * @param mediaContentData the media content, either {@code String} or {@code byte[]}
     * @return the URL / data URL to send
     * @throws IllegalArgumentException for any other content type
     */
    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
        if (mediaContentData instanceof String text) {
            return text;
        }
        if (mediaContentData instanceof byte[] bytes) {
            String encoded = Base64.getEncoder().encodeToString(bytes);
            return "data:%s;base64,%s".formatted(mimeType.toString(), encoded);
        }
        throw new IllegalArgumentException(
                "Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
    }

    /**
     * Wraps each Spring AI tool definition as an OpenAI {@code function} tool, carrying over its
     * description, name and input schema.
     * @param toolDefinitions the resolved tool definitions
     * @return the function tools, in the same order
     */
    private List<OpenAiApi.FunctionTool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
        return toolDefinitions.stream()
                .map(def -> new OpenAiApi.FunctionTool.Function(def.description(), def.name(), def.inputSchema()))
                .map(fn -> new OpenAiApi.FunctionTool(fn))
                .toList();
    }

    /**
     * Returns the model's default options, built from {@code defaultOptions} via
     * {@link OpenAiChatOptions#fromOptions}; presumably a copy, so the caller cannot mutate the
     * model's own instance (confirm against the Spring AI version in use).
     */
    @Override
    public ChatOptions getDefaultOptions() {
        OpenAiChatOptions snapshot = OpenAiChatOptions.fromOptions(this.defaultOptions);
        return snapshot;
    }

    /** Describes this model by its default options. */
    @Override
    public String toString() {
        return "AlibabaOpenAiChatModel [defaultOptions=%s]".formatted(this.defaultOptions);
    }

    /**
     * Replaces the convention used when reporting observation data for chat calls.
     * @param observationConvention the convention to use; must not be {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    /**
     * Starts a new fluent builder for {@link AlibabaOpenAiChatModel}.
     * @return a fresh builder populated with the builder's defaults
     */
    public static AlibabaOpenAiChatModel.Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link AlibabaOpenAiChatModel}.
     * <p>
     * Defaults:
     * <ul>
     * <li>model {@code OpenAiApi.DEFAULT_CHAT_MODEL} with temperature 0.7;</li>
     * <li>{@code RetryUtils.DEFAULT_RETRY_TEMPLATE};</li>
     * <li>{@code ObservationRegistry.NOOP}.</li>
     * </ul>
     * Tool execution is configured either through a {@link ToolCallingManager} or through the
     * deprecated {@link FunctionCallbackResolver} / tool-callback pair, never both. When neither is
     * set, {@code DEFAULT_TOOL_CALLING_MANAGER} is used.
     */
    public static final class Builder {

        private OpenAiApi openAiApi;

        private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
                .model(OpenAiApi.DEFAULT_CHAT_MODEL)
                .temperature(0.7)
                .build();

        private ToolCallingManager toolCallingManager;

        private FunctionCallbackResolver functionCallbackResolver;

        private List<FunctionCallback> toolFunctionCallbacks;

        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        private Builder() {
        }

        /** Sets the OpenAI-compatible API client used to send requests. */
        public AlibabaOpenAiChatModel.Builder openAiApi(OpenAiApi openAiApi) {
            this.openAiApi = openAiApi;
            return this;
        }

        /** Sets the default chat options, replacing the builder's built-in defaults. */
        public AlibabaOpenAiChatModel.Builder defaultOptions(OpenAiChatOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        /**
         * Sets the tool calling manager. Cannot be combined with
         * {@link #functionCallbackResolver(FunctionCallbackResolver)} or
         * {@link #toolFunctionCallbacks(List)}.
         */
        public AlibabaOpenAiChatModel.Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        /**
         * Sets the legacy function-callback resolver.
         * @deprecated use {@link #toolCallingManager(ToolCallingManager)} instead.
         */
        @Deprecated
        public AlibabaOpenAiChatModel.Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
            this.functionCallbackResolver = functionCallbackResolver;
            return this;
        }

        /**
         * Sets the legacy tool callbacks. They are only used together with a
         * {@link #functionCallbackResolver(FunctionCallbackResolver) function-callback resolver};
         * without one they are ignored.
         * @deprecated use {@link #toolCallingManager(ToolCallingManager)} instead.
         */
        @Deprecated
        public AlibabaOpenAiChatModel.Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
            this.toolFunctionCallbacks = toolFunctionCallbacks;
            return this;
        }

        /** Sets the retry template used for model calls. */
        public AlibabaOpenAiChatModel.Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        /** Sets the observation registry. */
        public AlibabaOpenAiChatModel.Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        /**
         * Creates the model from the configured settings.
         * @return a new {@link AlibabaOpenAiChatModel}
         * @throws IllegalArgumentException (via Spring {@code Assert.isNull}) if a tool calling
         * manager is combined with the legacy function-callback settings
         */
        public AlibabaOpenAiChatModel build() {
            if (toolCallingManager != null) {
                Assert.isNull(functionCallbackResolver,
                        "functionCallbackResolver cannot be set when toolCallingManager is set");
                Assert.isNull(toolFunctionCallbacks,
                        "toolFunctionCallbacks cannot be set when toolCallingManager is set");

                return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate,
                        observationRegistry);
            }

            if (functionCallbackResolver != null) {
                // toolCallingManager is necessarily null here: the branch above returned otherwise,
                // so no mutual-exclusion check is needed in this direction.
                List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
                        : List.of();

                return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,
                        retryTemplate, observationRegistry);
            }

            // Neither tool mechanism configured: fall back to the default manager.
            // NOTE(review): toolFunctionCallbacks set without a resolver are ignored on this path.
            return new AlibabaOpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
                    observationRegistry);
        }

    }

}

配置ChatModel

接下来,我们要把AlibabaOpenAiChatModel注册到Spring容器中。

修改CommonConfiguration,添加配置:

/**
 * Registers the {@link AlibabaOpenAiChatModel} bean.
 * <p>
 * Each connection setting (base URL, API key, project id, organization id) is taken from the
 * chat-specific properties when present, otherwise from the common connection properties. The
 * project and organization ids are forwarded as {@code OpenAI-Project} / {@code OpenAI-Organization}
 * headers only when they are set.
 */
@Bean
public AlibabaOpenAiChatModel alibabaOpenAiChatModel(OpenAiConnectionProperties commonProperties,
        OpenAiChatProperties chatProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider,
        ObjectProvider<WebClient.Builder> webClientBuilderProvider, ToolCallingManager toolCallingManager,
        RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
        ObjectProvider<ObservationRegistry> observationRegistry,
        ObjectProvider<ChatModelObservationConvention> observationConvention) {
    String baseUrl = firstNonBlank(chatProperties.getBaseUrl(), commonProperties.getBaseUrl());
    String apiKey = firstNonBlank(chatProperties.getApiKey(), commonProperties.getApiKey());
    String projectId = firstNonBlank(chatProperties.getProjectId(), commonProperties.getProjectId());
    String organizationId = firstNonBlank(chatProperties.getOrganizationId(), commonProperties.getOrganizationId());

    Map<String, List<String>> connectionHeaders = new HashMap<>();
    if (StringUtils.hasText(projectId)) {
        connectionHeaders.put("OpenAI-Project", List.of(projectId));
    }
    if (StringUtils.hasText(organizationId)) {
        connectionHeaders.put("OpenAI-Organization", List.of(organizationId));
    }

    RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder);
    WebClient.Builder webClientBuilder = webClientBuilderProvider.getIfAvailable(WebClient::builder);
    OpenAiApi openAiApi = OpenAiApi.builder()
            .baseUrl(baseUrl)
            .apiKey(new SimpleApiKey(apiKey))
            .headers(CollectionUtils.toMultiValueMap(connectionHeaders))
            .completionsPath(chatProperties.getCompletionsPath())
            // NOTE(review): hard-coded rather than read from properties; confirm if embeddings are routed here.
            .embeddingsPath("/v1/embeddings")
            .restClientBuilder(restClientBuilder)
            .webClientBuilder(webClientBuilder)
            .responseErrorHandler(responseErrorHandler)
            .build();

    AlibabaOpenAiChatModel chatModel = AlibabaOpenAiChatModel.builder()
            .openAiApi(openAiApi)
            .defaultOptions(chatProperties.getOptions())
            .toolCallingManager(toolCallingManager)
            .retryTemplate(retryTemplate)
            // getIfUnique already yields an ObservationRegistry; NOOP is used when no unique bean exists.
            .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
            .build();
    // Apply a custom observation convention only if one is available in the context.
    observationConvention.ifAvailable(chatModel::setObservationConvention);
    return chatModel;
}

/** Returns {@code preferred} when it contains text, otherwise {@code fallback}. */
private static String firstNonBlank(String preferred, String fallback) {
    return StringUtils.hasText(preferred) ? preferred : fallback;
}

修改ChatClient

最后,让之前定义的ChatClient都改用自定义的AlibabaOpenAiChatModel。

修改CommonConfiguration中的ChatClient配置:

/**
 * General-purpose chat client backed by {@link AlibabaOpenAiChatModel}.
 * <p>
 * Defaults to the {@code qwen-omni-turbo} model and a customer-assistant system prompt. Registers a
 * logging advisor and then a chat-memory advisor.
 */
@Bean
public ChatClient chatClient(AlibabaOpenAiChatModel model, ChatMemory chatMemory) {
    ChatOptions omniOptions = ChatOptions.builder().model("qwen-omni-turbo").build();
    String systemPrompt = "您是一家名为“黑马程序员”的职业教育公司的客户聊天助手,你的名字叫小黑。请以友好、乐于助人和愉快的方式解答用户的各种问题。";

    // Create the ChatClient builder on top of the custom model.
    ChatClient.Builder clientBuilder = ChatClient.builder(model)
            .defaultOptions(omniOptions)
            .defaultSystem(systemPrompt)
            .defaultAdvisors(new SimpleLoggerAdvisor()) // default advisor that logs requests and responses
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory));
    return clientBuilder.build();
}

/**
 * Customer-service chat client: uses {@code CUSTOMER_SERVICE_SYSTEM} as the system prompt, keeps
 * conversation memory, logs requests, and exposes {@link CourseTools} to the model as tools.
 */
@Bean
public ChatClient serviceChatClient(AlibabaOpenAiChatModel model, ChatMemory chatMemory, CourseTools courseTools) {
    ChatClient.Builder clientBuilder = ChatClient.builder(model);
    MessageChatMemoryAdvisor memoryAdvisor = new MessageChatMemoryAdvisor(chatMemory);
    SimpleLoggerAdvisor loggingAdvisor = new SimpleLoggerAdvisor();

    return clientBuilder
            .defaultSystem(CUSTOMER_SERVICE_SYSTEM)
            .defaultAdvisors(memoryAdvisor, loggingAdvisor)
            .defaultTools(courseTools)
            .build();
}

重启测试

OK,现在我们的应用已经能够支持stream模式的FunctionCalling和音频识别了。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值