diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
index 110b8c7ee27..aef51c09582 100644
--- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
+++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java
@@ -47,7 +47,6 @@
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
 import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
-import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
 import reactor.core.publisher.Flux;
 
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -58,6 +57,9 @@
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.EmptyUsage;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.AbstractToolCallSupport;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
@@ -77,6 +79,7 @@
 import org.springframework.ai.model.function.FunctionCallingOptions;
 import org.springframework.ai.retry.RetryUtils;
 import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants;
+import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
 import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
 import org.springframework.beans.factory.DisposableBean;
 import org.springframework.lang.NonNull;
@@ -95,6 +98,7 @@
  * @author Mark Pollack
  * @author Soby Chacko
  * @author Jihoon Kim
+ * @author Ilayaperumal Gopinathan
  * @since 0.8.1
  */
 public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean {
@@ -286,6 +290,10 @@ private static Schema jsonToSchema(String json) {
 	// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
 	@Override
 	public ChatResponse call(Prompt prompt) {
+		return this.internalCall(prompt, null);
+	}
+
+	public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
 
 		VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt);
 
@@ -310,8 +318,10 @@ public ChatResponse call(Prompt prompt) {
 				.flatMap(List::stream)
 				.toList();
 
-			ChatResponse chatResponse = new ChatResponse(generations,
-					toChatResponseMetadata(generateContentResponse));
+			GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata();
+			Usage currentUsage = (usage != null) ? new VertexAiUsage(usage) : new EmptyUsage();
+			Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
+			ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));
 
 			observationContext.setResponse(chatResponse);
 			return chatResponse;
@@ -321,7 +331,7 @@ public ChatResponse call(Prompt prompt) {
 			var toolCallConversation = handleToolCalls(prompt, response);
 			// Recursively call the call method with the tool call message
 			// conversation that contains the call responses.
-			return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
+			return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
 		}
 
 		return response;
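Reviewer note: the accumulation above delegates to UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse). As a rough sketch of the intended semantics only (not the actual helper; the cumulative method and the AggregatedUsage record below are hypothetical, written against the pre-1.0 Usage contract with getPromptTokens()/getGenerationTokens()):

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.ChatResponse;

final class UsageAccumulationSketch {

	// Approximates UsageUtils.getCumulativeUsage(current, previous): when a previous
	// response with usage exists, add its token counts to the current ones;
	// otherwise pass the current usage through unchanged.
	static Usage cumulative(Usage current, ChatResponse previous) {
		if (previous == null || previous.getMetadata() == null || previous.getMetadata().getUsage() == null
				|| UsageUtils.isEmpty(previous.getMetadata().getUsage())) {
			return current; // first round trip in the chain: nothing to accumulate
		}
		Usage prior = previous.getMetadata().getUsage();
		return new AggregatedUsage(current.getPromptTokens() + prior.getPromptTokens(),
				current.getGenerationTokens() + prior.getGenerationTokens());
	}

	// Hypothetical Usage implementation carrying the summed counts.
	record AggregatedUsage(Long promptTokens, Long generationTokens) implements Usage {

		@Override
		public Long getPromptTokens() {
			return this.promptTokens;
		}

		@Override
		public Long getGenerationTokens() {
			return this.generationTokens;
		}

		@Override
		public Long getTotalTokens() {
			return this.promptTokens + this.generationTokens;
		}

	}

}

Each recursive internalCall hop therefore reports the sum over every tool round trip, which is what the integration-test assertions further down exercise.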
@@ -330,6 +340,10 @@ public ChatResponse call(Prompt prompt) {
 
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
+		return this.internalStream(prompt, null);
+	}
+
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 		return Flux.deferContextual(contextView -> {
 
 			VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt);
@@ -349,32 +363,51 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 			try {
 				ResponseStream<GenerateContentResponse> responseStream = request.model
 					.generateContentStream(request.contents);
-
-				return Flux.fromStream(responseStream.stream()).switchMap(response -> {
-
+				Flux<ChatResponse> chatResponseFlux = Flux.fromStream(responseStream.stream()).switchMap(response -> {
 					List<Generation> generations = response.getCandidatesList()
 						.stream()
 						.map(this::responseCandidateToGeneration)
 						.flatMap(List::stream)
 						.toList();
-					ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response));
-
-					if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
+					GenerateContentResponse.UsageMetadata usage = response.getUsageMetadata();
+					Usage currentUsage = (usage != null) ? new VertexAiUsage(usage) : new EmptyUsage();
+					Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
+					ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));
+					return Flux.just(chatResponse);
+				}).buffer(2, 1).filter(bufferedResponses -> {
+					return bufferedResponses.size() == 2;
+				}).map(bufferList -> {
+					ChatResponse firstResponse = bufferList.get(0);
+					ChatResponse secondResponse = bufferList.get(1);
+					if (secondResponse != null && secondResponse.getMetadata() != null) {
+						// This is the usage from the final Chat response for a
+						// given Chat request.
+						Usage usage = secondResponse.getMetadata().getUsage();
+						if (!UsageUtils.isEmpty(usage)) {
+							// Store the usage from the final response to the
+							// penultimate response for accumulation.
+							return new ChatResponse(firstResponse.getResults(), toChatResponseMetadata(usage));
+						}
+					}
+					return firstResponse;
+				});
+				Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
+					if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
 							Set.of(FinishReason.STOP.name(), FinishReason.FINISH_REASON_UNSPECIFIED.name()))) {
-						var toolCallConversation = handleToolCalls(prompt, chatResponse);
+						var toolCallConversation = handleToolCalls(prompt, response);
 						// Recursively call the stream method with the tool call message
 						// conversation that contains the call responses.
-						return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+						return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
 					}
-
-					Flux<ChatResponse> chatResponseFlux = Flux.just(chatResponse)
-						.doOnError(observation::error)
-						.doFinally(s -> observation.stop())
-						.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
-
-					return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
-				});
+					else {
+						return Flux.just(response);
+					}
+				})
+					.doOnError(observation::error)
+					.doFinally(s -> observation.stop())
+					.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+				return new MessageAggregator().aggregate(flux, observationContext::setResponse);
 			}
 			catch (Exception e) {
 				throw new RuntimeException("Failed to generate content", e);
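The buffer(2, 1) / filter / map chain above is a one-element look-ahead: each streamed response is paired with its successor, the trailing single-element window (the final frame, the only one Gemini stamps with usage metadata) is filtered out, and its usage is folded back into the penultimate response. A self-contained sketch of the same pattern, runnable with reactor-core alone (the string values stand in for ChatResponse objects):

import reactor.core.publisher.Flux;

public class LookAheadDemo {

	public static void main(String[] args) {
		// Pretend each element is a streamed chunk; only the last one ("usage:42")
		// carries the metadata we want to attach to the element before it.
		Flux<String> chunks = Flux.just("a", "b", "usage:42");

		Flux<String> adjusted = chunks
			// sliding windows of size 2, advancing by 1: [a, b], [b, usage:42], [usage:42]
			.buffer(2, 1)
			// dropping the trailing one-element window removes the final frame itself
			.filter(window -> window.size() == 2)
			// each surviving element may borrow data from its successor
			.map(window -> window.get(1).startsWith("usage:")
					? window.get(0) + "+" + window.get(1).substring("usage:".length())
					: window.get(0));

		adjusted.subscribe(System.out::println); // prints: a, then b+42
	}

}

One design consequence worth flagging in review: because the size-one tail window is filtered out, the rewritten stream emits one element fewer than the raw Gemini stream, and the final frame survives only as usage metadata on the new last element.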
@@ -430,6 +463,10 @@ private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
 		return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build();
 	}
 
+	private ChatResponseMetadata toChatResponseMetadata(Usage usage) {
+		return ChatResponseMetadata.builder().withUsage(usage).build();
+	}
+
 	private VertexAiGeminiChatOptions vertexAiGeminiChatOptions(Prompt prompt) {
 		VertexAiGeminiChatOptions updatedRuntimeOptions = VertexAiGeminiChatOptions.builder().build();
 		if (prompt.getOptions() != null) {
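With the new toChatResponseMetadata(Usage) overload in place, callers observe one aggregated total instead of only the last round trip. A hedged usage sketch, with chatModel and promptOptions assumed to be configured as in the integration tests below (imports as in the model class above):

// Sketch: reading cumulative usage after a tool-calling exchange.
ChatResponse response = chatModel.call(new Prompt("What's the weather in Paris?", promptOptions));
Usage usage = response.getMetadata().getUsage();
// The totals now cover the initial request, each tool round trip, and the final answer.
System.out.println("prompt=" + usage.getPromptTokens() + ", generation=" + usage.getGenerationTokens()
		+ ", total=" + usage.getTotalTokens());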
diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java
index cf05a5a5257..c292b65f7d5 100644
--- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java
+++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java
@@ -127,8 +127,14 @@ public void functionCallTestInferredOpenApiSchema() {
 
 		logger.info("Response: {}", response);
 
+		assertThat(response.getResult()).isNotNull();
+		assertThat(response.getResult().getOutput()).isNotNull();
 		assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15");
 
+		assertThat(response.getMetadata()).isNotNull();
+		assertThat(response.getMetadata().getUsage()).isNotNull();
+		assertThat(response.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(230);
+
 		ChatResponse response2 = this.chatModel
 			.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
@@ -214,6 +220,37 @@ public void functionCallTestInferredOpenApiSchemaStream() {
 
 	}
 
+	@Test
+	public void functionCallTestUsageInferredOpenApiSchemaStream() {
+
+		UserMessage userMessage = new UserMessage(
+				"What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius.");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		var promptOptions = VertexAiGeminiChatOptions.builder()
+			.withModel(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
+			.withFunctionCallbacks(List.of(FunctionCallback.builder()
+				.function("getCurrentWeather", new MockWeatherService())
+				.schemaType(SchemaType.OPEN_API_SCHEMA)
+				.description("Get the current weather in a given location")
+				.inputType(MockWeatherService.Request.class)
+				.build()))
+			.build();
+
+		Flux<ChatResponse> responseFlux = this.chatModel.stream(new Prompt(messages, promptOptions));
+
+		ChatResponse chatResponse = responseFlux.blockLast();
+
+		logger.info("Response: {}", chatResponse);
+
+		assertThat(chatResponse).isNotNull();
+		assertThat(chatResponse.getMetadata()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(230);
+
+	}
+
 	public record PaymentInfoRequest(String id) {
 	}
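For context on the token numbers asserted above: VertexAiUsage adapts Gemini's UsageMetadata to Spring AI's Usage contract. A sketch of the expected mapping, hedged because only the Google SDK getters shown are assumed; the real adapter lives in org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage and may differ in detail:

import com.google.cloud.vertexai.api.GenerateContentResponse;
import org.springframework.ai.chat.metadata.Usage;

// Illustrative adapter: wraps the SDK's usage proto behind Spring AI's Usage interface.
public class VertexAiUsageSketch implements Usage {

	private final GenerateContentResponse.UsageMetadata usageMetadata;

	public VertexAiUsageSketch(GenerateContentResponse.UsageMetadata usageMetadata) {
		this.usageMetadata = usageMetadata;
	}

	@Override
	public Long getPromptTokens() {
		return (long) this.usageMetadata.getPromptTokenCount();
	}

	@Override
	public Long getGenerationTokens() {
		return (long) this.usageMetadata.getCandidatesTokenCount();
	}

}

Note that the streaming test asserts on blockLast(): after the look-ahead rewrite, the last emitted response is the one carrying the accumulated usage, so this checks the cumulative figure rather than a single frame's.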