diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 371772f73b7..23fce66201e 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -15,6 +15,37 @@ */ package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; @@ -41,37 +72,10 @@ import com.azure.ai.openai.models.FunctionDefinition; import com.azure.core.util.BinaryData; import com.azure.core.util.IterableStream; -import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.ChatResponseMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; -import org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; + import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; - /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. @@ -136,37 +140,16 @@ public ChatResponse call(Prompt prompt) { ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); - if (isToolFunctionCall(chatCompletions)) { - List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), - chatCompletions); + ChatResponse chatResponse = toChatResponse(chatCompletions); + + if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. - return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); + return this.call(new Prompt(toolCallConversation, prompt.getOptions())); } - List generations = nullSafeList(chatCompletions.getChoices()).stream() - .map(choice -> new Generation(choice.getMessage().getContent()) - .withGenerationMetadata(generateChoiceMetadata(choice))) - .toList(); - - PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); - - return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata)); - } - - public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { - Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); - String id = chatCompletions.getId(); - AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); - ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder() - .withId(id) - .withUsage(usage) - .withModel(chatCompletions.getModel()) - .withPromptMetadata(promptFilterMetadata) - .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) - .build(); - - return chatResponseMetadata; + return chatResponse; } @Override @@ -179,10 +162,9 @@ public Flux stream(Prompt prompt) { .getChatCompletionsStream(options.getModel(), options); final var isFunctionCall = new AtomicBoolean(false); - final var accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream) + final Flux accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream) // Note: the first chat completions can be ignored when using Azure OpenAI // service which is a known service bug. - // .skip(1) .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())) .map(chatCompletions -> { final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls(); @@ -204,58 +186,70 @@ public Flux stream(Prompt prompt) { }) .flatMap(mono -> mono); - return accessibleChatCompletionsFlux.switchMap(chatCompletion -> { - if (isToolFunctionCall(chatCompletion)) { - List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), - chatCompletion); - return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); + return accessibleChatCompletionsFlux.switchMap(chatCompletions -> { + + ChatResponse chatResponse = toChatResponse(chatCompletions); + + if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + var toolCallConversation = handleToolCalls(prompt, chatResponse); + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); } - return Mono.just(chatCompletion).flatMapIterable(ChatCompletions::getChoices).map(choice -> { - var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); - var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); - return new ChatResponse(List.of(generation)); - }); + return Mono.just(chatResponse); }); } - private List handleToolCallRequests(List previousMessages, ChatCompletions chatCompletion) { + private ChatResponse toChatResponse(ChatCompletions chatCompletions) { + + List generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletions.getId() != null ? chatCompletions.getId() : "", + "choiceIndex", choice.getIndex(), + "finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); - ChatRequestAssistantMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); - List assistantToolCalls = nativeAssistantMessage.getToolCalls() - .stream() - .map(tc -> (ChatCompletionsFunctionToolCall) tc) - .map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), toolCall.getType(), - toolCall.getFunction().getName(), toolCall.getFunction().getArguments())) - .toList(); + return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata)); + } + + private Generation buildGeneration(ChatChoice choice, Map metadata) { - AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(), - assistantToolCalls); + var responseMessage = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()); - ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage); + List toolCalls = responseMessage.getToolCalls() == null ? List.of() + : responseMessage.getToolCalls().stream().map(toolCall -> { + final var tc1 = (ChatCompletionsFunctionToolCall) toolCall; + String id = tc1.getId(); + String name = tc1.getFunction().getName(); + String arguments = tc1.getFunction().getArguments(); + return new AssistantMessage.ToolCall(id, "function", name, arguments); + }).toList(); - // History - List messages = new ArrayList<>(previousMessages); - messages.add(assistantMessage); - messages.add(toolResponseMessage); + var assistantMessage = new AssistantMessage(responseMessage.getContent(), metadata, toolCalls); + var generationMetadata = generateChoiceMetadata(choice); - return messages; + return new Generation(assistantMessage, generationMetadata); } - private ChatRequestAssistantMessage extractAssistantMessage(ChatCompletions response) { - final var accessibleChatChoice = response.getChoices().get(0); - var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); - ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); - final var toolCalls = responseMessage.getToolCalls(); - assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { - final var tc1 = (ChatCompletionsFunctionToolCall) tc; - var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(), - new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments())); - return ((ChatCompletionsToolCall) toDowncast); - }).toList()); - return assistantMessage; + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); + String id = chatCompletions.getId(); + Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); + ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder() + .withId(id) + .withUsage(usage) + .withModel(chatCompletions.getModel()) + .withPromptMetadata(promptFilterMetadata) + .withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint()) + .build(); + + return chatResponseMetadata; } /** @@ -560,21 +554,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { return copyOptions; } - protected boolean isToolFunctionCall(ChatCompletions chatCompletions) { - - if (chatCompletions == null || CollectionUtils.isEmpty(chatCompletions.getChoices())) { - return false; - } - - var choice = chatCompletions.getChoices().get(0); - - if (choice == null || choice.getFinishReason() == null) { - return false; - } - - return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS; - } - /** * Maps the SpringAI response format to the Azure response format * @param responseFormat SpringAI response format diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 8b257818aa3..2b876455c4c 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -206,7 +206,7 @@ public static String getDeploymentName() { return deploymentName; } else { - return "gpt-4-0125-preview"; + return "gpt-4o"; } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java index e8f9d61b9b5..dafd8f49a2a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/DeploymentNameUtil.java @@ -10,7 +10,7 @@ public static String getDeploymentName() { return deploymentName; } else { - return "gpt-4-0125-preview"; + return "gpt-4o"; } }