diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index ae5b3556933..9303e7a41fd 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -51,6 +51,7 @@ import com.azure.ai.openai.models.ChatRequestToolMessage; import com.azure.ai.openai.models.ChatRequestUserMessage; import com.azure.ai.openai.models.CompletionsFinishReason; +import com.azure.ai.openai.models.CompletionsUsage; import com.azure.ai.openai.models.ContentFilterResultsForPrompt; import com.azure.ai.openai.models.FunctionCall; import com.azure.core.util.BinaryData; @@ -70,6 +71,7 @@ import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.metadata.UsageUtils; import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -105,6 +107,7 @@ * @author timostark * @author Soby Chacko * @author Jihoon Kim + * @author Ilayaperumal Gopinathan * @see ChatModel * @see com.azure.ai.openai.OpenAIClient * @since 1.0.0 @@ -176,10 +179,10 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi this.observationRegistry = observationRegistry; } - public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata, + Usage usage) { Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); String id = chatCompletions.getId(); - Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); return ChatResponseMetadata.builder() .withId(id) .withUsage(usage) @@ -189,12 +192,40 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM .build(); } + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) { + Usage usage = (chatCompletions.getUsage() != null) ? AzureOpenAiUsage.from(chatCompletions) : new EmptyUsage(); + return from(chatCompletions, promptFilterMetadata, usage); + } + + public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata, + CompletionsUsage usage) { + return from(chatCompletions, promptFilterMetadata, AzureOpenAiUsage.from(usage)); + } + + public static ChatResponseMetadata from(ChatResponse chatResponse, Usage usage) { + Assert.notNull(chatResponse, "ChatResponse must not be null"); + ChatResponseMetadata chatResponseMetadata = chatResponse.getMetadata(); + ChatResponseMetadata.Builder builder = ChatResponseMetadata.builder(); + builder.withId(chatResponseMetadata.getId()) + .withUsage(usage) + .withModel(chatResponseMetadata.getModel()) + .withPromptMetadata(chatResponseMetadata.getPromptMetadata()); + if (chatResponseMetadata.containsKey("system-fingerprint")) { + builder.withKeyValue("system-fingerprint", chatResponseMetadata.get("system-fingerprint")); + } + return builder.build(); + } + public AzureOpenAiChatOptions getDefaultOptions() { return AzureOpenAiChatOptions.fromOptions(this.defaultOptions); } @Override public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt, null); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) @@ -210,7 +241,7 @@ public ChatResponse call(Prompt prompt) { ChatCompletionsOptionsAccessHelper.setStream(options, false); ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); - ChatResponse chatResponse = toChatResponse(chatCompletions); + ChatResponse chatResponse = toChatResponse(chatCompletions, previousChatResponse); observationContext.setResponse(chatResponse); return chatResponse; }); @@ -220,7 +251,7 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); } return response; @@ -228,6 +259,10 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS @Override public Flux stream(Prompt prompt) { + return this.internalStream(prompt, null); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { return Flux.deferContextual(contextView -> { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); @@ -279,16 +314,36 @@ public Flux stream(Prompt prompt) { }) .flatMap(mono -> mono); - return accessibleChatCompletionsFlux.switchMap(chatCompletions -> { - - ChatResponse chatResponse = toChatResponse(chatCompletions); + final Flux chatResponseFlux = accessibleChatCompletionsFlux.map(chatCompletion -> { + if (previousChatResponse == null) { + return toChatResponse(chatCompletion); + } + // Accumulate the usage from the previous chat response + CompletionsUsage usage = chatCompletion.getUsage(); + Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage.from(usage) : new EmptyUsage(); + Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); + return toChatResponse(chatCompletion, accumulatedUsage); + }).buffer(2, 1).map(bufferList -> { + ChatResponse chatResponse1 = bufferList.get(0); + if (options.getStreamOptions() != null && options.getStreamOptions().isIncludeUsage()) { + if (bufferList.size() == 2) { + ChatResponse chatResponse2 = bufferList.get(1); + if (chatResponse2 != null && chatResponse2.getMetadata() != null + && !UsageUtils.isEmpty(chatResponse2.getMetadata().getUsage())) { + return toChatResponse(chatResponse1, chatResponse2.getMetadata().getUsage()); + } + } + } + return chatResponse1; + }); + return chatResponseFlux.flatMap(chatResponse -> { if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { var toolCallConversation = handleToolCalls(prompt, chatResponse); // Recursively call the call method with the tool call message // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse); } Flux flux = Flux.just(chatResponse) @@ -305,6 +360,44 @@ public Flux stream(Prompt prompt) { private ChatResponse toChatResponse(ChatCompletions chatCompletions) { + List generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletions.getId() != null ? chatCompletions.getId() : "", + "choiceIndex", choice.getIndex(), + "finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); + + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + + return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata)); + } + + private ChatResponse toChatResponse(ChatCompletions chatCompletions, Usage usage) { + + List generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> { + // @formatter:off + Map metadata = Map.of( + "id", chatCompletions.getId() != null ? chatCompletions.getId() : "", + "choiceIndex", choice.getIndex(), + "finishReason", choice.getFinishReason() != null ? String.valueOf(choice.getFinishReason()) : ""); + // @formatter:on + return buildGeneration(choice, metadata); + }).toList(); + + PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); + + return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, usage)); + } + + private ChatResponse toChatResponse(ChatResponse chatResponse, Usage usage) { + return new ChatResponse(chatResponse.getResults(), from(chatResponse, usage)); + } + + private ChatResponse toChatResponse(ChatCompletions chatCompletions, ChatResponse previousChatResponse) { + List generations = nullSafeList(chatCompletions.getChoices()).stream().map(choice -> { // @formatter:off Map metadata = Map.of( @@ -316,8 +409,12 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions) { }).toList(); PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions); - - return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata)); + Usage currentUsage = null; + if (chatCompletions.getUsage() != null) { + currentUsage = AzureOpenAiUsage.from(chatCompletions); + } + Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse); + return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata, cumulativeUsage)); } private Generation buildGeneration(ChatChoice choice, Map metadata) { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index a80817cd1a3..614b0a69fa7 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -24,6 +24,7 @@ import java.util.stream.Collectors; import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; import com.azure.core.credential.AzureKeyCredential; import org.junit.jupiter.api.Test; import org.slf4j.Logger; @@ -80,7 +81,12 @@ void functionCallTest() { logger.info("Response: {}", response); + assertThat(response.getResult()).isNotNull(); + assertThat(response.getResult().getOutput()).isNotNull(); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800); } @Test @@ -142,6 +148,34 @@ void streamFunctionCallTest() { } + @Test + void streamFunctionCallUsageTest() { + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(this.selectedModel) + .withFunctionCallbacks(List.of(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .withStreamOptions(streamOptions) + .build(); + + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + ChatResponse chatResponse = response.last().block(); + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(600).isLessThan(800); + + } + @Test void functionCallSequentialAndStreamTest() {