From 61a478c4f6f4c5435d8e3809fd020782f20ead9b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 16 Jul 2024 14:47:02 +0200 Subject: [PATCH] Add high-level function calling support for Azure OpenAI Resolves #613, #1042 --- .../ai/azure/openai/AzureOpenAiChatModel.java | 305 +++++++----------- .../azure/openai/AzureOpenAiChatModelIT.java | 6 +- .../AzureOpenAiChatModelFunctionCallIT.java | 70 +++- 3 files changed, 192 insertions(+), 189 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index b883458481e..77902f353ac 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -15,6 +15,33 @@ */ package org.springframework.ai.azure.openai; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractToolCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; @@ -41,32 +68,9 @@ import com.azure.ai.openai.models.FunctionDefinition; import com.azure.core.util.BinaryData; import com.azure.core.util.IterableStream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata; -import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.AbstractFunctionCallSupport; -import org.springframework.ai.model.function.FunctionCallbackContext; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import reactor.core.publisher.Flux; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by @@ -79,18 +83,16 @@ * @author Grogdunn * @author Benoit Moussaud * @author luocongqiu + * @author timostark * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ -public class AzureOpenAiChatModel extends - AbstractFunctionCallSupport implements ChatModel { +public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel { private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo"; private static final Float DEFAULT_TEMPERATURE = 0.7f; - private final Logger logger = LoggerFactory.getLogger(getClass()); - /** * The {@link OpenAIClient} used to interact with the Azure OpenAI service. */ @@ -122,17 +124,6 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatO this.defaultOptions = options; } - /** - * @deprecated since 0.8.0, use - * {@link #AzureOpenAiChatModel(OpenAIClient, AzureOpenAiChatOptions)} instead. - */ - @Deprecated(forRemoval = true, since = "0.8.0") - public AzureOpenAiChatModel withDefaultOptions(AzureOpenAiChatOptions defaultOptions) { - Assert.notNull(defaultOptions, "DefaultOptions must not be null"); - this.defaultOptions = defaultOptions; - return this; - } - public AzureOpenAiChatOptions getDefaultOptions() { return AzureOpenAiChatOptions.fromOptions(this.defaultOptions); } @@ -143,9 +134,15 @@ public ChatResponse call(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(false); - logger.trace("Azure ChatCompletionsOptions: {}", options); - ChatCompletions chatCompletions = this.callWithFunctionSupport(options); - logger.trace("Azure ChatCompletions: {}", chatCompletions); + ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); + + if (isToolFunctionCall(chatCompletions)) { + List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + chatCompletions); + // Recursively call the call method with the tool call message + // conversation that contains the call responses. + return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } List generations = nullSafeList(chatCompletions.getChoices()).stream() .map(choice -> new Generation(choice.getMessage().getContent()) @@ -167,10 +164,8 @@ public Flux stream(Prompt prompt) { IterableStream chatCompletionsStream = this.openAIClient .getChatCompletionsStream(options.getModel(), options); - Flux chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream); - final var isFunctionCall = new AtomicBoolean(false); - final var accessibleChatCompletionsFlux = chatCompletionsFlux + final var accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream) // Note: the first chat completions can be ignored when using Azure OpenAI // service which is a known service bug. .skip(1) @@ -193,16 +188,59 @@ public Flux stream(Prompt prompt) { return List.of(reduce); }) .flatMap(mono -> mono); - return accessibleChatCompletionsFlux - .switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options, - Flux.just(accessibleChatCompletions))) - .flatMapIterable(ChatCompletions::getChoices) - .map(choice -> { + + return accessibleChatCompletionsFlux.switchMap(chatCompletion -> { + if (isToolFunctionCall(chatCompletion)) { + List toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), + chatCompletion); + return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions())); + } + + return Mono.just(chatCompletion).flatMapIterable(ChatCompletions::getChoices).map(choice -> { var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent(); var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice)); return new ChatResponse(List.of(generation)); }); + }); + } + private List handleToolCallRequests(List previousMessages, ChatCompletions chatCompletion) { + + ChatRequestAssistantMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion); + + List assistantToolCalls = nativeAssistantMessage.getToolCalls() + .stream() + .map(tc -> (ChatCompletionsFunctionToolCall) tc) + .map(toolCall -> new AssistantMessage.ToolCall(toolCall.getId(), toolCall.getType(), + toolCall.getFunction().getName(), toolCall.getFunction().getArguments())) + .toList(); + + AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.getContent(), Map.of(), + assistantToolCalls); + + ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage); + + // History + List messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.add(toolResponseMessage); + + return messages; + } + + private ChatRequestAssistantMessage extractAssistantMessage(ChatCompletions response) { + final var accessibleChatChoice = response.getChoices().get(0); + var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) + .orElse(accessibleChatChoice.getDelta()); + ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); + final var toolCalls = responseMessage.getToolCalls(); + assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { + final var tc1 = (ChatCompletionsFunctionToolCall) tc; + var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(), + new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments())); + return ((ChatCompletionsToolCall) toDowncast); + }).toList()); + return assistantMessage; } /** @@ -215,16 +253,14 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { List azureMessages = prompt.getInstructions() .stream() .map(this::fromSpringAiMessage) + .flatMap(List::stream) .toList(); ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); if (this.defaultOptions != null) { - // JSON merge doesn't due to Azure OpenAI service bug: - // https://github.com/Azure/azure-sdk-for-java/issues/38183 - // options = ModelOptionsUtils.merge(options, this.defaultOptions, - // ChatCompletionsOptions.class); - options = merge(options, this.defaultOptions); + + options = this.merge(options, this.defaultOptions); Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL); @@ -234,11 +270,7 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { if (prompt.getOptions() != null) { AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AzureOpenAiChatOptions.class); - // JSON merge doesn't due to Azure OpenAI service bug: - // https://github.com/Azure/azure-sdk-for-java/issues/38183 - // options = ModelOptionsUtils.merge(runtimeOptions, options, - // ChatCompletionsOptions.class); - options = merge(updatedRuntimeOptions, options); + options = this.merge(updatedRuntimeOptions, options); Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, IS_RUNTIME_CALL); @@ -270,7 +302,7 @@ private List getFunctionTools(Set }).toList(); } - private ChatRequestMessage fromSpringAiMessage(Message message) { + private List fromSpringAiMessage(Message message) { switch (message.getMessageType()) { case USER: @@ -284,15 +316,41 @@ private ChatRequestMessage fromSpringAiMessage(Message message) { new ChatMessageImageUrl(media.getData().toString()))) .toList()); } - return new ChatRequestUserMessage(items); + return List.of(new ChatRequestUserMessage(items)); case SYSTEM: - return new ChatRequestSystemMessage(message.getContent()); - case ASSISTANT: - return new ChatRequestAssistantMessage(message.getContent()); + return List.of(new ChatRequestSystemMessage(message.getContent())); + case ASSISTANT: { + AssistantMessage assistantMessage = (AssistantMessage) message; + List toolCalls = null; + if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { + toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { + var function = new FunctionCall(toolCall.name(), toolCall.arguments()); + return new ChatCompletionsFunctionToolCall(toolCall.id(), function); + }) + .map(tc -> ((ChatCompletionsToolCall) tc)) // !!! + .toList(); + } + var azureAssistantMessage = new ChatRequestAssistantMessage(message.getContent()); + azureAssistantMessage.setToolCalls(toolCalls); + return List.of(azureAssistantMessage); + } + case TOOL: { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; + + toolMessage.getResponses().forEach(response -> { + Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id"); + Assert.isTrue(response.name() != null, "ToolResponseMessage must have a name"); + }); + + return toolMessage.getResponses() + .stream() + .map(tr -> new ChatRequestToolMessage(tr.responseData(), tr.id())) + .map(crtm -> ((ChatRequestMessage) crtm)) + .toList(); + } default: throw new IllegalArgumentException("Unknown message type " + message.getMessageType()); } - } private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { @@ -438,58 +496,6 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, return mergedAzureOptions; } - /** - * Merges the fromOptions into the toOptions and returns a new ChatCompletionsOptions - * instance. - * @param fromOptions the ChatCompletionsOptions to merge from. - * @param toOptions the ChatCompletionsOptions to merge to. - * @return a new ChatCompletionsOptions instance. - */ - private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) { - - if (fromOptions == null) { - return toOptions; - } - - ChatCompletionsOptions mergedOptions = this.copy(toOptions); - - if (fromOptions.getMaxTokens() != null) { - mergedOptions.setMaxTokens(fromOptions.getMaxTokens()); - } - if (fromOptions.getLogitBias() != null) { - mergedOptions.setLogitBias(fromOptions.getLogitBias()); - } - if (fromOptions.getStop() != null) { - mergedOptions.setStop(fromOptions.getStop()); - } - if (fromOptions.getTemperature() != null) { - mergedOptions.setTemperature(fromOptions.getTemperature()); - } - if (fromOptions.getTopP() != null) { - mergedOptions.setTopP(fromOptions.getTopP()); - } - if (fromOptions.getFrequencyPenalty() != null) { - mergedOptions.setFrequencyPenalty(fromOptions.getFrequencyPenalty()); - } - if (fromOptions.getPresencePenalty() != null) { - mergedOptions.setPresencePenalty(fromOptions.getPresencePenalty()); - } - if (fromOptions.getN() != null) { - mergedOptions.setN(fromOptions.getN()); - } - if (fromOptions.getUser() != null) { - mergedOptions.setUser(fromOptions.getUser()); - } - if (fromOptions.getModel() != null) { - mergedOptions.setModel(fromOptions.getModel()); - } - if (fromOptions.getResponseFormat() != null) { - mergedOptions.setResponseFormat(fromOptions.getResponseFormat()); - } - - return mergedOptions; - } - /** * Copy the fromOptions into a new ChatCompletionsOptions instance. * @param fromOptions the ChatCompletionsOptions to copy from. @@ -537,67 +543,6 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { return copyOptions; } - @Override - protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, - ChatRequestMessage responseMessage, List conversationHistory) { - - // Every tool-call item requires a separate function call and a response (TOOL) - // message. - for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) { - - var functionName = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName(); - String functionArguments = ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getArguments(); - - if (!this.functionCallbackRegister.containsKey(functionName)) { - throw new IllegalStateException("No function callback found for function name: " + functionName); - } - - String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - - // Add the function response to the conversation. - conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); - } - - // Recursively call chatCompletionWithTools until the model doesn't call a - // functions anymore. - ChatCompletionsOptions newRequest = new ChatCompletionsOptions(conversationHistory); - - newRequest = merge(previousRequest, newRequest); - - return newRequest; - } - - @Override - protected List doGetUserMessages(ChatCompletionsOptions request) { - return request.getMessages(); - } - - @Override - protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) { - final var accessibleChatChoice = response.getChoices().get(0); - var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage()) - .orElse(accessibleChatChoice.getDelta()); - ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage(""); - final var toolCalls = responseMessage.getToolCalls(); - assistantMessage.setToolCalls(toolCalls.stream().map(tc -> { - final var tc1 = (ChatCompletionsFunctionToolCall) tc; - var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(), - new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments())); - return ((ChatCompletionsToolCall) toDowncast); - }).toList()); - return assistantMessage; - } - - @Override - protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) { - return this.openAIClient.getChatCompletions(request.getModel(), request); - } - - @Override - protected Flux doChatCompletionStream(ChatCompletionsOptions request) { - return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request)); - } - @Override protected boolean isToolFunctionCall(ChatCompletions chatCompletions) { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index 079b840e2cb..38fe7f36e34 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -150,7 +150,7 @@ void beanOutputConverterRecords() { Generation generation = chatModel.call(prompt).getResult(); ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); - System.out.println(actorsFilms); + logger.info("" + actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -180,7 +180,7 @@ void beanStreamOutputConverterRecords() { .collect(Collectors.joining()); ActorsFilmsRecord actorsFilms = outputParser.convert(generationTextFromStream); - System.out.println(actorsFilms); + logger.info("" + actorsFilms); assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); assertThat(actorsFilms.movies()).hasSize(5); } @@ -223,7 +223,7 @@ public OpenAIClient openAIClient() { @Bean public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) { return new AzureOpenAiChatModel(openAIClient, - AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build()); + AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(1000).build()); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index f4349e17211..8b257818aa3 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -80,9 +81,31 @@ void functionCallTest() { logger.info("Response: {}", response); - assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30"); - assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10"); - assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15"); + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Test + void functionCallSequentialTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); } @Test @@ -116,9 +139,44 @@ void streamFunctionCallTest() { assertThat(counter.get()).isGreaterThan(30).as("The response should be chunked in more than 30 messages"); - assertThat(content).containsAnyOf("30.0", "30"); - assertThat(content).containsAnyOf("10.0", "10"); - assertThat(content).containsAnyOf("15.0", "15"); + assertThat(content).contains("30", "10", "15"); + + } + + @Test + void functionCallSequentialAndStreamTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? If the weather is above 25 degrees, please check the weather in Tokyo and Paris."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + var response = chatModel.stream(new Prompt(messages, promptOptions)); + + final var counter = new AtomicInteger(); + String content = response.doOnEach(listSignal -> counter.getAndIncrement()) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + logger.info("Response: {}", response); + + assertThat(content).contains("30", "10", "15"); } @SpringBootConfiguration