From 7f3297746a64fbb8f6a18722f078cd55ceb8cb4b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 13 Feb 2025 00:46:52 +0100 Subject: [PATCH] feat(vertex-ai): Refactor Gemini for new tool calling API - Migrate from function calling to tool calling API - Add support for Gemini 2.0 models (flash, flash-lite) - Implement JSON schema to OpenAPI schema conversion - Add builder pattern for improved configuration - Deprecate legacy function calling constructors and methods - Update default model to GEMINI_2_0_FLASH - Add comprehensive test coverage for tool calling - Upgrade victools dependency to 4.37.0 - Update the Vertex Tool calling docs Part of the #2207 epic Signed-off-by: Christian Tzolov --- models/spring-ai-vertex-ai-gemini/pom.xml | 14 +- .../gemini/VertexAiGeminiChatModel.java | 410 +++++++++++++----- .../gemini/VertexAiGeminiChatOptions.java | 223 +++++----- .../gemini/schema/JsonSchemaConverter.java | 168 +++++++ .../schema/VertextToolCallingManager.java | 66 +++ .../gemini/CreateGeminiRequestTests.java | 146 ++++--- .../VertexAiChatModelObservationIT.java | 18 +- .../gemini/VertexAiGeminiChatModelIT.java | 26 +- .../gemini/VertexAiGeminiRetryTests.java | 2 +- .../schema/JsonSchemaConverterTests.java | 210 +++++++++ .../MockWeatherService.java | 2 +- ...texAiGeminiChatModelFunctionCallingIT.java | 46 +- .../VertexAiGeminiChatModelToolCallingIT.java | 212 +++++++++ ...GeminiPaymentTransactionDeprecatedIT.java} | 7 +- .../VertexAiGeminiPaymentTransactionIT.java | 247 +++++++++++ ...texAiGeminiPaymentTransactionMethodIT.java | 248 +++++++++++ pom.xml | 2 +- .../util/json/schema/JsonSchemaGenerator.java | 2 +- .../src/main/antora/modules/ROOT/nav.adoc | 1 - .../ROOT/pages/api/chat/bedrock-converse.adoc | 2 +- .../vertexai-gemini-chat-functions.adoc | 207 --------- .../pages/api/chat/vertexai-gemini-chat.adoc | 46 +- .../modules/ROOT/pages/api/functions.adoc | 19 +- .../VertexAiGeminiAutoConfiguration.java | 33 +- .../gemini/VertexAiGeminiChatProperties.java | 2 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../FunctionCallWithFunctionWrapperIT.java | 14 +- .../FunctionCallWithPromptFunctionIT.java | 14 +- 28 files changed, 1850 insertions(+), 541 deletions(-) create mode 100644 models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverter.java create mode 100644 models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertextToolCallingManager.java create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverterTests.java rename models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/{function => tool}/MockWeatherService.java (97%) rename models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/{function => tool}/VertexAiGeminiChatModelFunctionCallingIT.java (81%) create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java rename models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/{function/VertexAiGeminiPaymentTransactionIT.java => tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java} (97%) create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java create mode 100644 
models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java delete mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc diff --git a/models/spring-ai-vertex-ai-gemini/pom.xml b/models/spring-ai-vertex-ai-gemini/pom.xml index 07db09ee045..1f3324553ed 100644 --- a/models/spring-ai-vertex-ai-gemini/pom.xml +++ b/models/spring-ai-vertex-ai-gemini/pom.xml @@ -16,7 +16,8 @@ --> + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> 4.0.0 org.springframework.ai @@ -53,6 +54,17 @@ + + com.github.victools + jsonschema-generator + ${victools.version} + + + com.github.victools + jsonschema-module-jackson + ${victools.version} + + com.google.cloud google-cloud-vertexai diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index ac83c57176d..81880d7d731 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,10 +18,8 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -47,7 +45,10 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -75,9 +76,15 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; +import org.springframework.ai.vertexai.gemini.schema.VertextToolCallingManager; import org.springframework.beans.factory.DisposableBean; import org.springframework.lang.NonNull; import org.springframework.retry.support.RetryTemplate; @@ -102,6 +109,10 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements private static final 
ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + + private final Logger logger = LoggerFactory.getLogger(getClass()); + private final VertexAI vertexAI; private final VertexAiGeminiChatOptions defaultOptions; @@ -118,29 +129,54 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements */ private final ObservationRegistry observationRegistry; + /** + * Tool calling manager used to call tools. + */ + private final ToolCallingManager toolCallingManager; + /** * Conventions to use for generating observations. */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI) { this(vertexAI, VertexAiGeminiChatOptions.builder().model(ChatModel.GEMINI_1_5_PRO).temperature(0.8).build()); } + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options) { this(vertexAI, options, null); } + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, FunctionCallbackResolver functionCallbackResolver) { this(vertexAI, options, functionCallbackResolver, List.of()); } + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks) { this(vertexAI, options, functionCallbackResolver, toolFunctionCallbacks, RetryUtils.DEFAULT_RETRY_TEMPLATE); } + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, RetryTemplate retryTemplate) { @@ -148,22 +184,49 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions opti ObservationRegistry.NOOP); } + /** + * @deprecated Use {@link VertexAiGeminiChatModel.Builder}. + */ + @Deprecated public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions options, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - super(functionCallbackResolver, options, toolFunctionCallbacks); + this(vertexAI, options, + LegacyToolCallingManager.builder() + .functionCallbackResolver(functionCallbackResolver) + .functionCallbacks(toolFunctionCallbacks) + .build(), + retryTemplate, observationRegistry); + logger.warn("This constructor is deprecated and will be removed in the next milestone. 
" + + "Please use the new constructor accepting ToolCallingManager instead."); + + } + + public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defaultOptions, + ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + + super(null, VertexAiGeminiChatOptions.builder().build(), List.of()); Assert.notNull(vertexAI, "VertexAI must not be null"); - Assert.notNull(options, "VertexAiGeminiChatOptions must not be null"); - Assert.notNull(options.getModel(), "VertexAiGeminiChatOptions.modelName must not be null"); + Assert.notNull(defaultOptions, "VertexAiGeminiChatOptions must not be null"); + Assert.notNull(defaultOptions.getModel(), "VertexAiGeminiChatOptions.modelName must not be null"); Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + Assert.notNull(toolCallingManager, "ToolCallingManager must not be null"); this.vertexAI = vertexAI; - this.defaultOptions = options; - this.generationConfig = toGenerationConfig(options); + this.defaultOptions = defaultOptions; + this.generationConfig = toGenerationConfig(defaultOptions); this.retryTemplate = retryTemplate; this.observationRegistry = observationRegistry; + + if (toolCallingManager instanceof VertextToolCallingManager) { + this.toolCallingManager = toolCallingManager; + } + else { + this.toolCallingManager = new VertextToolCallingManager(toolCallingManager); + } } private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type) { @@ -287,13 +350,16 @@ private static Schema jsonToSchema(String json) { // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini @Override public ChatResponse call(Prompt prompt) { + var requestPrompt = this.buildRequestPrompt(prompt); + return this.internalCall(requestPrompt); + } - VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt); + private ChatResponse internalCall(Prompt prompt) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(VertexAiGeminiConstants.PROVIDER_NAME) - .requestOptions(vertexAiGeminiChatOptions) + .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -301,7 +367,7 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> this.retryTemplate.execute(context -> { - var geminiRequest = createGeminiRequest(prompt, vertexAiGeminiChatOptions); + var geminiRequest = createGeminiRequest(prompt); GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); @@ -318,26 +384,94 @@ public ChatResponse call(Prompt prompt) { return chatResponse; })); - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(FinishReason.STOP.name()))) { - var toolCallConversation = handleToolCalls(prompt, response); - // Recursively call the call method with the tool call message - // conversation that contains the call responses. - return this.call(new Prompt(toolCallConversation, prompt.getOptions())); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. 
+ return ChatResponse.builder() + .from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build(); + } + else { + // Send the tool execution result back to the model. + return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } } return response; } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + VertexAiGeminiChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + VertexAiGeminiChatOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + VertexAiGeminiChatOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + VertexAiGeminiChatOptions.class); + } + } + + // Define request options by merging runtime options and default options + VertexAiGeminiChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + VertexAiGeminiChatOptions.class); + + // Merge @JsonIgnore-annotated options explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); + requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), + this.defaultOptions.getToolNames())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); + + requestOptions.setGoogleSearchRetrieval(ModelOptionsUtils.mergeOption( + runtimeOptions.getGoogleSearchRetrieval(), this.defaultOptions.getGoogleSearchRetrieval())); + requestOptions.setSafetySettings(ModelOptionsUtils.mergeOption(runtimeOptions.getSafetySettings(), + this.defaultOptions.getSafetySettings())); + } + else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); + requestOptions.setToolNames(this.defaultOptions.getToolNames()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); + + requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval()); + requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings()); + } + + ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + @Override public Flux stream(Prompt prompt) { + var requestPrompt = this.buildRequestPrompt(prompt); + return this.internalStream(requestPrompt); + } + + public Flux internalStream(Prompt prompt) { return Flux.deferContextual(contextView -> { - VertexAiGeminiChatOptions vertexAiGeminiChatOptions = vertexAiGeminiChatOptions(prompt); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) 
.provider(VertexAiGeminiConstants.PROVIDER_NAME) - .requestOptions(vertexAiGeminiChatOptions) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -345,41 +479,56 @@ public Flux stream(Prompt prompt) { this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - var request = createGeminiRequest(prompt, vertexAiGeminiChatOptions); + + var request = createGeminiRequest(prompt); try { ResponseStream responseStream = request.model .generateContentStream(request.contents); - return Flux.fromStream(responseStream.stream()).switchMap(response -> { - - List generations = response.getCandidatesList() - .stream() - .map(this::responseCandidateToGeneration) - .flatMap(List::stream) - .toList(); - - ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response)); - - if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse, - Set.of(FinishReason.STOP.name(), FinishReason.FINISH_REASON_UNSPECIFIED.name()))) { - var toolCallConversation = handleToolCalls(prompt, chatResponse); - // Recursively call the stream method with the tool call message - // conversation that contains the call responses. - return this.stream(new Prompt(toolCallConversation, prompt.getOptions())); + Flux chatResponse1 = Flux.fromStream(responseStream.stream()) + .switchMap(response2 -> Mono.just(response2).map(response -> { + + List generations = response.getCandidatesList() + .stream() + .map(this::responseCandidateToGeneration) + .flatMap(List::stream) + .toList(); + + return new ChatResponse(generations, toChatResponseMetadata(response)); + + })); + + // @formatter:off + Flux chatResponseFlux = chatResponse1.flatMap(response -> { + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions())); + } + } + else { + return Flux.just(response); } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on; - Flux chatResponseFlux = Flux.just(chatResponse) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); - return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); - }); } catch (Exception e) { throw new RuntimeException("Failed to generate content", e); } + }); } @@ -448,54 +597,40 @@ private VertexAiGeminiChatOptions vertexAiGeminiChatOptions(Prompt prompt) { VertexAiGeminiChatOptions.class); return updatedRuntimeOptions; - } - /** - * Tests access to the {@link #createGeminiRequest(Prompt, VertexAiGeminiChatOptions)} - * method. 
- */ - GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updatedRuntimeOptions) { - - Set functionsForThisRequest = new HashSet<>(); + GeminiRequest createGeminiRequest(Prompt prompt) { - GenerationConfig generationConfig = this.generationConfig; + VertexAiGeminiChatOptions requestOptions = (VertexAiGeminiChatOptions) prompt.getOptions(); - var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel()) - .setVertexAi(this.vertexAI) - .setSafetySettings(toGeminiSafetySettings(this.defaultOptions.getSafetySettings())); + var generativeModelBuilder = new GenerativeModel.Builder().setVertexAi(this.vertexAI) + .setSafetySettings(toGeminiSafetySettings(requestOptions.getSafetySettings())); - if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, - FunctionCallingOptions.class, VertexAiGeminiChatOptions.class); - } - else { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - VertexAiGeminiChatOptions.class); - } - functionsForThisRequest.addAll(runtimeFunctionCallbackConfigurations(updatedRuntimeOptions)); + if (requestOptions.getModel() != null) { + generativeModelBuilder.setModelName(requestOptions.getModel()); } - - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); + else { + generativeModelBuilder.setModelName(this.defaultOptions.getModel()); } - if (updatedRuntimeOptions != null) { - - if (StringUtils.hasText(updatedRuntimeOptions.getModel()) - && !updatedRuntimeOptions.getModel().equals(this.defaultOptions.getModel())) { - // Override model name - generativeModelBuilder.setModelName(updatedRuntimeOptions.getModel()); - } + GenerationConfig generationConfig = this.generationConfig; - generationConfig = toGenerationConfig(updatedRuntimeOptions); + if (requestOptions != null) { + generationConfig = toGenerationConfig(requestOptions); } // Add the enabled functions definitions to the request's tools parameter. 
List tools = new ArrayList<>(); - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - tools.addAll(this.getFunctionTools(functionsForThisRequest)); + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + final List functionDeclarations = toolDefinitions.stream() + .map(toolDefinition -> FunctionDeclaration.newBuilder() + .setName(toolDefinition.name()) + .setDescription(toolDefinition.description()) + .setParameters(jsonToSchema(toolDefinition.inputSchema())) + .build()) + .toList(); + tools.add(Tool.newBuilder().addAllFunctionDeclarations(functionDeclarations).build()); } if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options && options.getGoogleSearchRetrieval()) { @@ -505,13 +640,13 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat .build(); tools.add(googleSearchRetrievalTool); } + if (!CollectionUtils.isEmpty(tools)) { generativeModelBuilder.setTools(tools); } - if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options - && !CollectionUtils.isEmpty(options.getSafetySettings())) { - generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(options.getSafetySettings())); + if (!CollectionUtils.isEmpty(requestOptions.getSafetySettings())) { + generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(requestOptions.getSafetySettings())); } generativeModelBuilder.setGenerationConfig(generationConfig); @@ -582,22 +717,6 @@ private List toGeminiSafetySettings(List getFunctionTools(Set functionNames) { - - final var tool = Tool.newBuilder(); - - final List functionDeclarations = this.resolveFunctionCallbacks(functionNames) - .stream() - .map(functionCallback -> FunctionDeclaration.newBuilder() - .setName(functionCallback.getName()) - .setDescription(functionCallback.getDescription()) - .setParameters(jsonToSchema(functionCallback.getInputTypeSchema())) - .build()) - .toList(); - tool.addAllFunctionDeclarations(functionDeclarations); - return List.of(tool.build()); - } - /** * Generates the content response based on the provided Gemini request. Package * protected for testing purposes. 
@@ -662,9 +781,15 @@ public enum ChatModel implements ChatModelDescription { GEMINI_PRO("gemini-pro"), - GEMINI_1_5_PRO("gemini-1.5-pro-001"), + GEMINI_1_5_PRO("gemini-1.5-pro-002"), + + GEMINI_1_5_FLASH("gemini-1.5-flash-002"), + + GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b-001"), + + GEMINI_2_0_FLASH("gemini-2.0-flash"), - GEMINI_1_5_FLASH("gemini-1.5-flash-001"); + GEMINI_2_0_FLASH_LIGHT("gemini-2.0-flash-lite-preview-02-05"); public final String value; @@ -688,4 +813,95 @@ public record GeminiRequest(List contents, GenerativeModel model) { } + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private VertexAI vertexAI; + + private VertexAiGeminiChatOptions defaultOptions = VertexAiGeminiChatOptions.builder() + .temperature(0.7) + .topP(1.0) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .build(); + + private ToolCallingManager toolCallingManager; + + private FunctionCallbackResolver functionCallbackResolver; + + private List toolFunctionCallbacks; + + private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private Builder() { + } + + public Builder vertexAI(VertexAI vertexAI) { + this.vertexAI = vertexAI; + return this; + } + + public Builder defaultOptions(VertexAiGeminiChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + @Deprecated + public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { + this.toolFunctionCallbacks = toolFunctionCallbacks; + return this; + } + + public Builder retryTemplate(RetryTemplate retryTemplate) { + this.retryTemplate = retryTemplate; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public VertexAiGeminiChatModel build() { + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver cannot be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks cannot be set when toolCallingManager is set"); + + return new VertexAiGeminiChatModel(vertexAI, defaultOptions, toolCallingManager, retryTemplate, + observationRegistry); + } + + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager cannot be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolFunctionCallbacks != null ? 
this.toolFunctionCallbacks + : List.of(); + + return new VertexAiGeminiChatModel(vertexAI, defaultOptions, functionCallbackResolver, toolCallbacks, + retryTemplate, observationRegistry); + } + + return new VertexAiGeminiChatModel(vertexAI, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, + observationRegistry); + } + + } + } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 631a4df2717..9a39d9195f1 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,11 +29,12 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -46,7 +47,7 @@ * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class VertexAiGeminiChatOptions implements FunctionCallingOptions { +public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { // https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig @@ -95,40 +96,36 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions { private @JsonProperty("responseMimeType") String responseMimeType; /** - * Tool Function Callbacks to register with the ChatModel. - * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. - * For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions - * from the registry to be used by the ChatModel chat completion requests. + * Collection of {@link ToolCallback}s to be used for tool calling in the chat + * completion requests. */ @JsonIgnore - private List functionCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** - * List of functions, identified by their names, to configure for function calling in - * the chat completion requests. - * Functions with those names must exist in the functionCallbacks registry. - * The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. - * - * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. - * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. 
+ * Collection of tool names to be resolved at runtime and used for tool calling in the + * chat completion requests. */ @JsonIgnore - private Set functions = new HashSet<>(); + private Set toolNames = new HashSet<>(); /** - * Use Google search Grounding feature + * Whether to enable the tool execution lifecycle internally in ChatModel. */ @JsonIgnore - private boolean googleSearchRetrieval = false; + private Boolean internalToolExecutionEnabled; @JsonIgnore - private List safetySettings = new ArrayList<>(); + private Map toolContext = new HashMap<>(); + /** + * Use Google search Grounding feature + */ @JsonIgnore - private Boolean proxyToolCalls; + private Boolean googleSearchRetrieval = false; @JsonIgnore - private Map toolContext; + private List safetySettings = new ArrayList<>(); public static Builder builder() { return new Builder(); @@ -145,13 +142,13 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setCandidateCount(fromOptions.getCandidateCount()); options.setMaxOutputTokens(fromOptions.getMaxOutputTokens()); options.setModel(fromOptions.getModel()); - options.setFunctionCallbacks(fromOptions.getFunctionCallbacks()); + options.setToolCallbacks(fromOptions.getToolCallbacks()); options.setResponseMimeType(fromOptions.getResponseMimeType()); - options.setFunctions(fromOptions.getFunctions()); + options.setToolNames(fromOptions.getToolNames()); options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); options.setSafetySettings(fromOptions.getSafetySettings()); - options.setProxyToolCalls(fromOptions.getProxyToolCalls()); + options.setInternalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()); options.setToolContext(fromOptions.getToolContext()); return options; } @@ -236,20 +233,67 @@ public void setResponseMimeType(String mimeType) { this.responseMimeType = mimeType; } + @Override + @JsonIgnore + @Deprecated public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } + @Override + @JsonIgnore + @Deprecated public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } + @Override + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + @Deprecated public Set getFunctions() { - return this.functions; + return this.getToolNames(); } + @JsonIgnore + @Deprecated public void setFunctions(Set functions) { - this.functions = functions; + this.setToolNames(functions); + } + + @Override + public Set getToolNames() { + return this.toolNames; + } + + @Override + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = 
internalToolExecutionEnabled; } @Override @@ -264,11 +308,11 @@ public Double getPresencePenalty() { return null; } - public boolean getGoogleSearchRetrieval() { + public Boolean getGoogleSearchRetrieval() { return this.googleSearchRetrieval; } - public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) { + public void setGoogleSearchRetrieval(Boolean googleSearchRetrieval) { this.googleSearchRetrieval = googleSearchRetrieval; } @@ -281,13 +325,17 @@ public void setSafetySettings(List safetySettings) this.safetySettings = safetySettings; } + @Deprecated @Override + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } @Override @@ -314,18 +362,18 @@ public boolean equals(Object o) { && Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount) && Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model) && Objects.equals(this.responseMimeType, that.responseMimeType) - && Objects.equals(this.functionCallbacks, that.functionCallbacks) - && Objects.equals(this.functions, that.functions) + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) - && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) && Objects.equals(this.toolContext, that.toolContext); } @Override public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, - this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions, - this.googleSearchRetrieval, this.safetySettings, this.proxyToolCalls, this.toolContext); + this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, + this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext); } @Override @@ -333,9 +381,9 @@ public String toString() { return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature=" + this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount=" + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' - + ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks=" - + this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}'; + + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks + + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval + + ", safetySettings=" + this.safetySettings + '}'; } @Override @@ -343,67 +391,6 @@ public VertexAiGeminiChatOptions copy() { return fromOptions(this); } - public FunctionCallingOptions merge(ChatOptions options) { - VertexAiGeminiChatOptions.Builder builder = VertexAiGeminiChatOptions.builder(); - - // Merge chat-specific options - builder.model(options.getModel() != null ? 
options.getModel() : this.getModel()) - .maxOutputTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxOutputTokens()) - .stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences()) - .temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature()) - .topP(options.getTopP() != null ? options.getTopP() : this.getTopP()) - .topK(options.getTopK() != null ? options.getTopK() : this.getTopK()); - - // Try to get function-specific properties if options is a FunctionCallingOptions - if (options instanceof FunctionCallingOptions functionOptions) { - builder.proxyToolCalls(functionOptions.getProxyToolCalls() != null ? functionOptions.getProxyToolCalls() - : this.proxyToolCalls); - - Set functions = new HashSet<>(); - if (this.functions != null) { - functions.addAll(this.functions); - } - if (functionOptions.getFunctions() != null) { - functions.addAll(functionOptions.getFunctions()); - } - builder.functions(functions); - - List functionCallbacks = new ArrayList<>(); - if (this.functionCallbacks != null) { - functionCallbacks.addAll(this.functionCallbacks); - } - if (functionOptions.getFunctionCallbacks() != null) { - functionCallbacks.addAll(functionOptions.getFunctionCallbacks()); - } - builder.functionCallbacks(functionCallbacks); - - Map context = new HashMap<>(); - if (this.toolContext != null) { - context.putAll(this.toolContext); - } - if (functionOptions.getToolContext() != null) { - context.putAll(functionOptions.getToolContext()); - } - builder.toolContext(context); - } - else { - // If not a FunctionCallingOptions, preserve current function-specific - // properties - builder.proxyToolCalls(this.proxyToolCalls); - builder.functions(this.functions != null ? new HashSet<>(this.functions) : null); - builder.functionCallbacks(this.functionCallbacks != null ? new ArrayList<>(this.functionCallbacks) : null); - builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null); - } - - // Preserve Vertex AI Gemini-specific properties - builder.candidateCount(this.candidateCount) - .responseMimeType(this.responseMimeType) - .googleSearchRetrieval(this.googleSearchRetrieval) - .safetySettings(this.safetySettings != null ? 
new ArrayList<>(this.safetySettings) : null); - - return builder.build(); - } - public enum TransportType { GRPC, REST @@ -460,20 +447,35 @@ public Builder responseMimeType(String mimeType) { return this; } + @Deprecated public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; + return toolCallbacks(functionCallbacks); + } + + public Builder toolCallbacks(List toolCallbacks) { + this.options.toolCallbacks = toolCallbacks; return this; } + @Deprecated public Builder functions(Set functionNames) { - Assert.notNull(functionNames, "Function names must not be null"); - this.options.functions = functionNames; + return this.toolNames(functionNames); + } + + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "Function names must not be null"); + this.options.toolNames = toolNames; return this; } + @Deprecated public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); + return this.toolName(functionName); + } + + public Builder toolName(String toolName) { + Assert.hasText(toolName, "Function name must not be empty"); + this.options.toolNames.add(toolName); return this; } @@ -488,8 +490,13 @@ public Builder safetySettings(List safetySettings) return this; } + @Deprecated public Builder proxyToolCalls(boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + return this.internalToolExecutionEnabled(proxyToolCalls); + } + + public Builder internalToolExecutionEnabled(boolean internalToolExecutionEnabled) { + this.options.internalToolExecutionEnabled = internalToolExecutionEnabled; return this; } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverter.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverter.java new file mode 100644 index 00000000000..97b4bdaae66 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverter.java @@ -0,0 +1,168 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.vertexai.gemini.schema; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import org.springframework.util.Assert; + +/** + * Utility class for converting JSON Schema to OpenAPI schema format. 
+ */ +public final class JsonSchemaConverter { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private JsonSchemaConverter() { + // Prevent instantiation + } + + public static ObjectNode fromJson(String jsonString) { + try { + return (ObjectNode) OBJECT_MAPPER.readTree(jsonString); + } + catch (Exception e) { + throw new RuntimeException("Failed to parse JSON: " + jsonString, e); + } + } + + /** + * Converts a JSON Schema ObjectNode to OpenAPI schema format. + * @param jsonSchemaNode The input JSON Schema as ObjectNode + * @return ObjectNode containing the OpenAPI schema + * @throws IllegalArgumentException if jsonSchemaNode is null + */ + public static ObjectNode convertToOpenApiSchema(ObjectNode jsonSchemaNode) { + Assert.notNull(jsonSchemaNode, "JSON Schema node must not be null"); + + try { + // Convert to OpenAPI schema using our custom conversion logic + ObjectNode openApiSchema = convertSchema(jsonSchemaNode, OBJECT_MAPPER.getNodeFactory()); + + // Add OpenAPI-specific metadata + if (!openApiSchema.has("openapi")) { + openApiSchema.put("openapi", "3.0.0"); + } + + return openApiSchema; + } + catch (Exception e) { + throw new IllegalStateException("Failed to convert JSON Schema to OpenAPI format: " + e.getMessage(), e); + } + } + + /** + * Copies common properties from source to target node. + * @param source The source ObjectNode containing JSON Schema properties + * @param target The target ObjectNode to copy properties to + */ + private static void copyCommonProperties(ObjectNode source, ObjectNode target) { + Assert.notNull(source, "Source node must not be null"); + Assert.notNull(target, "Target node must not be null"); + String[] commonProperties = { + // Core schema properties + "type", "format", "description", "default", "maximum", "minimum", "maxLength", "minLength", "pattern", + "enum", "multipleOf", "uniqueItems", + // OpenAPI specific properties + "example", "deprecated", "readOnly", "writeOnly", "nullable", "discriminator", "xml", "externalDocs" }; + + for (String prop : commonProperties) { + if (source.has(prop)) { + target.set(prop, source.get(prop)); + } + } + } + + /** + * Handles JSON Schema specific attributes and converts them to OpenAPI format. 
+ * @param source The source ObjectNode containing JSON Schema + * @param target The target ObjectNode to store OpenAPI schema + */ + private static void handleJsonSchemaSpecifics(ObjectNode source, ObjectNode target) { + Assert.notNull(source, "Source node must not be null"); + Assert.notNull(target, "Target node must not be null"); + if (source.has("properties")) { + ObjectNode properties = target.putObject("properties"); + source.get("properties").fields().forEachRemaining(entry -> { + if (entry.getValue() instanceof ObjectNode) { + properties.set(entry.getKey(), + convertSchema((ObjectNode) entry.getValue(), OBJECT_MAPPER.getNodeFactory())); + } + }); + } + + // Handle required array + if (source.has("required")) { + target.set("required", source.get("required")); + } + + // Convert JSON Schema specific attributes to OpenAPI equivalents + if (source.has("additionalProperties")) { + JsonNode additionalProps = source.get("additionalProperties"); + if (additionalProps.isBoolean()) { + target.put("additionalProperties", additionalProps.asBoolean()); + } + else if (additionalProps.isObject()) { + target.set("additionalProperties", + convertSchema((ObjectNode) additionalProps, OBJECT_MAPPER.getNodeFactory())); + } + } + + // Handle arrays + if (source.has("items")) { + JsonNode items = source.get("items"); + if (items.isObject()) { + target.set("items", convertSchema((ObjectNode) items, OBJECT_MAPPER.getNodeFactory())); + } + } + + // Handle allOf, anyOf, oneOf + String[] combiners = { "allOf", "anyOf", "oneOf" }; + for (String combiner : combiners) { + if (source.has(combiner)) { + JsonNode combinerNode = source.get(combiner); + if (combinerNode.isArray()) { + target.putArray(combiner).addAll((com.fasterxml.jackson.databind.node.ArrayNode) combinerNode); + } + } + } + } + + /** + * Recursively converts a JSON Schema node to OpenAPI format. + * @param source The source ObjectNode containing JSON Schema + * @param factory The JsonNodeFactory to create new nodes + * @return The converted OpenAPI schema as ObjectNode + */ + private static ObjectNode convertSchema(ObjectNode source, + com.fasterxml.jackson.databind.node.JsonNodeFactory factory) { + Assert.notNull(source, "Source node must not be null"); + Assert.notNull(factory, "JsonNodeFactory must not be null"); + + ObjectNode converted = factory.objectNode(); + copyCommonProperties(source, converted); + handleJsonSchemaSpecifics(source, converted); + return converted; + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertextToolCallingManager.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertextToolCallingManager.java new file mode 100644 index 00000000000..d7159344254 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/schema/VertextToolCallingManager.java @@ -0,0 +1,66 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.vertexai.gemini.schema; + +import java.util.List; + +import com.fasterxml.jackson.databind.node.ObjectNode; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.schema.JsonSchemaGenerator; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ + +public class VertextToolCallingManager implements ToolCallingManager { + + private final ToolCallingManager delegateToolCallingManager; + + public VertextToolCallingManager(ToolCallingManager delegateToolCallingManager) { + this.delegateToolCallingManager = delegateToolCallingManager; + } + + @Override + public List resolveToolDefinitions(ToolCallingChatOptions chatOptions) { + + List toolDefinitions = delegateToolCallingManager.resolveToolDefinitions(chatOptions); + + return toolDefinitions.stream().map(td -> { + ObjectNode jsonSchema = JsonSchemaConverter.fromJson(td.inputSchema()); + ObjectNode openApiSchema = JsonSchemaConverter.convertToOpenApiSchema(jsonSchema); + JsonSchemaGenerator.convertTypeValuesToUpperCase(openApiSchema); + + return ToolDefinition.builder() + .name(td.name()) + .description(td.description()) + .inputSchema(openApiSchema.toPrettyString()) + .build(); + }).toList(); + } + + @Override + public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) { + return this.delegateToolCallingManager.executeToolCalls(prompt, chatResponse); + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java index d92bf2b76e7..1b7ac94a4a0 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java @@ -33,8 +33,12 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.GeminiRequest; -import org.springframework.ai.vertexai.gemini.function.MockWeatherService; +import org.springframework.ai.vertexai.gemini.tool.MockWeatherService; import org.springframework.util.MimeTypeUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -51,10 +55,13 @@ public class CreateGeminiRequestTests { @Test public void createRequestWithChatOptions() { - var client = new VertexAiGeminiChatModel(this.vertexAI, - VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + 
.defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()) + .build(); - GeminiRequest request = client.createGeminiRequest(new Prompt("Test message content"), null); + GeminiRequest request = client.createGeminiRequest(client + .buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder().build()))); assertThat(request.contents()).hasSize(1); @@ -62,8 +69,8 @@ public void createRequestWithChatOptions() { assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL"); assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f); - request = client.createGeminiRequest(new Prompt("Test message content", - VertexAiGeminiChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()), null); + request = client.createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content", + VertexAiGeminiChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()))); assertThat(request.contents()).hasSize(1); @@ -80,10 +87,13 @@ public void createRequestWithSystemMessage() throws MalformedURLException { var userMessage = new UserMessage("User Message Text", List.of(Media.builder().mimeType(MimeTypeUtils.IMAGE_PNG).data(new URL("http://example.com")).build())); - var client = new VertexAiGeminiChatModel(this.vertexAI, - VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + .defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()) + .build(); - GeminiRequest request = client.createGeminiRequest(new Prompt(List.of(systemMessage, userMessage)), null); + GeminiRequest request = client + .createGeminiRequest(client.buildRequestPrompt(new Prompt(List.of(systemMessage, userMessage)))); assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL"); assertThat(request.model().getGenerationConfig().getTemperature()).isEqualTo(66.6f); @@ -109,22 +119,30 @@ public void promptOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(this.vertexAI, - VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").build()); + var toolCallingManager = ToolCallingManager.builder().build(); - var request = client.createGeminiRequest(new Prompt("Test message content", + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + .defaultOptions(VertexAiGeminiChatOptions.builder().model("DEFAULT_MODEL").build()) + .toolCallingManager(toolCallingManager) + .build(); + + var requestPrompt = client.buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder() .model("PROMPT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) - .build()), - null); + .build())); + + var request = client.createGeminiRequest(requestPrompt); + + List toolDefinitions = toolCallingManager + .resolveToolDefinitions((ToolCallingChatOptions) requestPrompt.getOptions()); - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); + assertThat(toolDefinitions).hasSize(1); + 
assertThat(toolDefinitions.get(0).name()).isSameAs(TOOL_FUNCTION_NAME); assertThat(request.contents()).hasSize(1); assertThat(request.model().getSystemInstruction()).isNotPresent(); @@ -140,33 +158,44 @@ public void defaultOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new VertexAiGeminiChatModel(this.vertexAI, - VertexAiGeminiChatOptions.builder() - .model("DEFAULT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) - .build()); + var toolCallingManager = ToolCallingManager.builder().build(); + + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + .toolCallingManager(toolCallingManager) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model("DEFAULT_MODEL") + .functionCallbacks(List.of(FunctionCallback.builder() + .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build()) + .build(); + + var requestPrompt = client.buildRequestPrompt(new Prompt("Test message content")); - var request = client.createGeminiRequest(new Prompt("Test message content"), null); + var request = client.createGeminiRequest(requestPrompt); - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); - assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) - .isEqualTo("Get the weather in location"); + List toolDefinitions = toolCallingManager + .resolveToolDefinitions((ToolCallingChatOptions) requestPrompt.getOptions()); + + assertThat(toolDefinitions).hasSize(1); + assertThat(toolDefinitions.get(0).name()).isSameAs(TOOL_FUNCTION_NAME); + assertThat(toolDefinitions.get(0).description()).isEqualTo("Get the weather in location"); assertThat(request.contents()).hasSize(1); assertThat(request.model().getSystemInstruction()).isNotPresent(); assertThat(request.model().getModelName()).isEqualTo("DEFAULT_MODEL"); - assertThat(request.model().getTools()).as("Default Options callback functions are not automatically enabled!") - .isNullOrEmpty(); + assertThat(request.model().getTools()).hasSize(1); // Explicitly enable the function - request = client.createGeminiRequest(new Prompt("Test message content", - VertexAiGeminiChatOptions.builder().function(TOOL_FUNCTION_NAME).build()), null); + + requestPrompt = client.buildRequestPrompt(new Prompt("Test message content", + VertexAiGeminiChatOptions.builder().toolName(TOOL_FUNCTION_NAME).build())); + + request = client.createGeminiRequest(requestPrompt); assertThat(request.model().getTools()).hasSize(1); assertThat(request.model().getTools().get(0).getFunctionDeclarations(0).getName()) @@ -174,43 +203,48 @@ public void defaultOptionsTools() { .isEqualTo(TOOL_FUNCTION_NAME); // Override the default options function with one from the prompt - request = client.createGeminiRequest(new Prompt("Test message content", + requestPrompt = client.buildRequestPrompt(new Prompt("Test message content", VertexAiGeminiChatOptions.builder() .functionCallbacks(List.of(FunctionCallback.builder() .function(TOOL_FUNCTION_NAME, new MockWeatherService()) .description("Overridden function description") .inputType(MockWeatherService.Request.class) .build())) - .build()), - null); + .build())); + request = 
client.createGeminiRequest(requestPrompt); assertThat(request.model().getTools()).hasSize(1); assertThat(request.model().getTools().get(0).getFunctionDeclarations(0).getName()) .as("Explicitly enabled function") .isEqualTo(TOOL_FUNCTION_NAME); - assertThat(client.getFunctionCallbackRegister()).hasSize(1); - assertThat(client.getFunctionCallbackRegister()).containsKeys(TOOL_FUNCTION_NAME); - assertThat(client.getFunctionCallbackRegister().get(TOOL_FUNCTION_NAME).getDescription()) - .isEqualTo("Overridden function description"); + toolDefinitions = toolCallingManager + .resolveToolDefinitions((ToolCallingChatOptions) requestPrompt.getOptions()); + + assertThat(toolDefinitions).hasSize(1); + assertThat(toolDefinitions.get(0).name()).isSameAs(TOOL_FUNCTION_NAME); + assertThat(toolDefinitions.get(0).description()).isEqualTo("Overridden function description"); } @Test public void createRequestWithGenerationConfigOptions() { - var client = new VertexAiGeminiChatModel(this.vertexAI, - VertexAiGeminiChatOptions.builder() - .model("DEFAULT_MODEL") - .temperature(66.6) - .maxOutputTokens(100) - .topK(10) - .topP(5.0) - .stopSequences(List.of("stop1", "stop2")) - .candidateCount(1) - .responseMimeType("application/json") - .build()); - - GeminiRequest request = client.createGeminiRequest(new Prompt("Test message content"), null); + var client = VertexAiGeminiChatModel.builder() + .vertexAI(this.vertexAI) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model("DEFAULT_MODEL") + .temperature(66.6) + .maxOutputTokens(100) + .topK(10) + .topP(5.0) + .stopSequences(List.of("stop1", "stop2")) + .candidateCount(1) + .responseMimeType("application/json") + .build()) + .build(); + + GeminiRequest request = client + .createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content"))); assertThat(request.contents()).hasSize(1); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java index 2cb70793ade..6b97b9e70e9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiChatModelObservationIT.java @@ -39,7 +39,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; -import org.springframework.retry.support.RetryTemplate; import static org.assertj.core.api.Assertions.assertThat; @@ -66,7 +65,7 @@ void beforeEach() { void observationForChatOperation() { var options = VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO.getValue()) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .temperature(0.7) .stopSequences(List.of("this-is-the-end")) .maxOutputTokens(2048) @@ -88,7 +87,7 @@ void observationForChatOperation() { void observationForStreamingOperation() { var options = VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO.getValue()) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .temperature(0.7) .stopSequences(List.of("this-is-the-end")) .maxOutputTokens(2048) @@ -128,7 +127,7 @@ private void validate(ChatResponseMetadata responseMetadata) { 
AiProvider.VERTEX_AI.value()) .hasLowCardinalityKeyValue( ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL.asString(), - VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO.getValue()) + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .hasHighCardinalityKeyValue( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") .hasHighCardinalityKeyValue( @@ -177,9 +176,14 @@ public VertexAI vertexAiApi() { @Bean public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi, TestObservationRegistry observationRegistry) { - return new VertexAiGeminiChatModel(vertexAi, - VertexAiGeminiChatOptions.builder().model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO).build(), - null, List.of(), RetryTemplate.defaultInstance(), observationRegistry); + + return VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .observationRegistry(observationRegistry) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .build()) + .build(); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index f17a9aefd8d..fe307ef15d9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -26,6 +26,7 @@ import com.google.cloud.vertexai.Transport; import com.google.cloud.vertexai.VertexAI; import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; @@ -41,6 +42,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.Media; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel; import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -88,19 +90,24 @@ void testMessageHistory() { @Test void googleSearchTool() { - Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().googleSearchRetrieval(true).build()); + Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder() + .model(ChatModel.GEMINI_1_5_PRO) // Only the pro model supports the google + // search tool + .googleSearchRetrieval(true) + .build()); ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard", "Bartholomew"); } @Test + @Disabled void testSafetySettings() { List safetySettings = List.of(new VertexAiGeminiSafetySetting.Builder() .withCategory(VertexAiGeminiSafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) .withThreshold(VertexAiGeminiSafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) .build()); - Prompt prompt = new Prompt("What are common digital attack vectors?", - VertexAiGeminiChatOptions.builder().safetySettings(safetySettings).build()); + Prompt prompt = new Prompt("How to make cocktail Molotov bomb at home?", + VertexAiGeminiChatOptions.builder().model(ChatModel.GEMINI_PRO).safetySettings(safetySettings).build()); 
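+ // Descriptive comment (editorial addition): with BLOCK_LOW_AND_ABOVE set on HARM_CATEGORY_DANGEROUS_CONTENT,
+ // this prompt is expected to be blocked, so the assertion below checks for the "SAFETY" finish reason.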
ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("SAFETY"); } @@ -235,7 +242,8 @@ void multiModalityTest() throws IOException { // Response should contain something like: // I see a bunch of bananas in a golden basket. The bananas are ripe and yellow. - // There are also some red apples in the basket. The basket is sitting on a table. + // There are also some red apples in the basket. The basket is sitting on a + // table. // The background is a blurred light blue color.' assertThat(response.getResult().getOutput().getText()).satisfies(content -> { long count = Stream.of("bananas", "apple", "basket").filter(content::contains).count(); @@ -293,10 +301,12 @@ public VertexAI vertexAiApi() { @Bean public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi) { - return new VertexAiGeminiChatModel(vertexAi, - VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO) - .build()); + return VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .build()) + .build(); } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java index f42ae4f7bb0..577e4b2086a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiRetryTests.java @@ -75,7 +75,7 @@ public void setUp() { VertexAiGeminiChatOptions.builder() .temperature(0.7) .topP(1.0) - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_PRO.getValue()) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .build(), null, Collections.emptyList(), this.retryTemplate); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverterTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverterTests.java new file mode 100644 index 00000000000..e60e1a683c6 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/schema/JsonSchemaConverterTests.java @@ -0,0 +1,210 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.gemini.schema; + +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Nested; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link JsonSchemaConverter}. 
+ * + * @author Christian Tzolov + */ +class JsonSchemaConverterTests { + + @Test + void fromJsonShouldParseValidJson() { + String json = "{\"type\":\"object\",\"properties\":{\"name\":{\"type\":\"string\"}}}"; + ObjectNode result = JsonSchemaConverter.fromJson(json); + + assertThat(result.get("type").asText()).isEqualTo("object"); + assertThat(result.get("properties").get("name").get("type").asText()).isEqualTo("string"); + } + + @Test + void fromJsonShouldThrowOnInvalidJson() { + String invalidJson = "{invalid:json}"; + assertThatThrownBy(() -> JsonSchemaConverter.fromJson(invalidJson)).isInstanceOf(RuntimeException.class) + .hasMessageContaining("Failed to parse JSON"); + } + + @Test + void convertToOpenApiSchemaShouldThrowOnNullInput() { + assertThatThrownBy(() -> JsonSchemaConverter.convertToOpenApiSchema(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("JSON Schema node must not be null"); + } + + @Nested + class SchemaConversionTests { + + @Test + void shouldConvertBasicSchema() { + String json = """ + { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name property" + } + }, + "required": ["name"] + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + assertThat(result.get("openapi").asText()).isEqualTo("3.0.0"); + assertThat(result.get("type").asText()).isEqualTo("object"); + assertThat(result.get("properties").get("name").get("type").asText()).isEqualTo("string"); + assertThat(result.get("properties").get("name").get("description").asText()).isEqualTo("The name property"); + assertThat(result.get("required").get(0).asText()).isEqualTo("name"); + } + + @Test + void shouldHandleArrayTypes() { + String json = """ + { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + assertThat(result.get("properties").get("tags").get("type").asText()).isEqualTo("array"); + assertThat(result.get("properties").get("tags").get("items").get("type").asText()).isEqualTo("string"); + } + + @Test + void shouldHandleAdditionalProperties() { + String json = """ + { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + assertThat(result.get("additionalProperties").get("type").asText()).isEqualTo("string"); + } + + @Test + void shouldHandleCombiningSchemas() { + String json = """ + { + "type": "object", + "allOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}}, + {"type": "object", "properties": {"age": {"type": "integer"}}} + ] + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + assertThat(result.get("allOf")).isNotNull(); + assertThat(result.get("allOf").isArray()).isTrue(); + assertThat(result.get("allOf").size()).isEqualTo(2); + } + + @Test + void shouldCopyCommonProperties() { + String json = """ + { + "type": "string", + "format": "email", + "description": "Email address", + "minLength": 5, + "maxLength": 100, + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$", + "example": "user@example.com", + "deprecated": false + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + 
assertThat(result.get("type").asText()).isEqualTo("string"); + assertThat(result.get("format").asText()).isEqualTo("email"); + assertThat(result.get("description").asText()).isEqualTo("Email address"); + assertThat(result.get("minLength").asInt()).isEqualTo(5); + assertThat(result.get("maxLength").asInt()).isEqualTo(100); + assertThat(result.get("pattern").asText()).isEqualTo("^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"); + assertThat(result.get("example").asText()).isEqualTo("user@example.com"); + assertThat(result.get("deprecated").asBoolean()).isFalse(); + } + + @Test + void shouldHandleNestedObjects() { + String json = """ + { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"} + } + } + } + } + } + } + """; + + ObjectNode result = JsonSchemaConverter.convertToOpenApiSchema(JsonSchemaConverter.fromJson(json)); + + assertThat(result.get("properties") + .get("user") + .get("properties") + .get("address") + .get("properties") + .get("street") + .get("type") + .asText()).isEqualTo("string"); + assertThat(result.get("properties") + .get("user") + .get("properties") + .get("address") + .get("properties") + .get("city") + .get("type") + .asText()).isEqualTo("string"); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/MockWeatherService.java similarity index 97% rename from models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java rename to models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/MockWeatherService.java index d79b317cda1..22e98ff8c2d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/MockWeatherService.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/MockWeatherService.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vertexai.gemini.function; +package org.springframework.ai.vertexai.gemini.tool; import java.util.function.Function; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java similarity index 81% rename from models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java rename to models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java index 82beeb07f20..f0604e48eac 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelFunctionCallingIT.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.springframework.ai.vertexai.gemini.function; +package org.springframework.ai.vertexai.gemini.tool; import java.util.ArrayList; import java.util.List; @@ -49,6 +49,7 @@ @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +@Deprecated public class VertexAiGeminiChatModelFunctionCallingIT { private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiChatModelFunctionCallingIT.class); @@ -106,7 +107,7 @@ public void functionCallTestInferredOpenApiSchema() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) .functionCallbacks(List.of( FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, @@ -138,6 +139,43 @@ public void functionCallTestInferredOpenApiSchema() { } + @Test + public void functionCallTestInferredOpenApiSchema2() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .functionCallbacks(List.of( + FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .description("Get the current weather in a given location.") + .inputType(MockWeatherService.Request.class) + .build(), + FunctionToolCallback.builder("get_payment_status", new PaymentStatus()) + .description( + "Retrieves the payment status for transaction. 
For example what is the payment status for transaction 700?") + .inputType(PaymentInfoRequest.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + + ChatResponse response2 = this.chatModel + .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); + + logger.info("Response: {}", response2); + + assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("transaction 696 is PAYED"); + + } + @Test public void functionCallTestInferredOpenApiSchemaStream() { @@ -147,7 +185,7 @@ public void functionCallTestInferredOpenApiSchemaStream() { List messages = new ArrayList<>(List.of(userMessage)); var promptOptions = VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) @@ -207,7 +245,7 @@ public VertexAI vertexAiApi() { public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi) { return new VertexAiGeminiChatModel(vertexAi, VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO) + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) .temperature(0.9) .build()); } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java new file mode 100644 index 00000000000..8aa65819946 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java @@ -0,0 +1,212 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vertexai.gemini.tool; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +public class VertexAiGeminiChatModelToolCallingIT { + + private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiChatModelToolCallingIT.class); + + @Autowired + private VertexAiGeminiChatModel chatModel; + + @Test + public void functionCallExplicitOpenApiSchema() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + String openApiSchema = """ + { + "type": "OBJECT", + "properties": { + "location": { + "type": "STRING", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit" : { + "type" : "STRING", + "enum" : [ "C", "F" ], + "description" : "Temperature unit" + } + }, + "required": ["location", "unit"] + } + """; + + var promptOptions = VertexAiGeminiChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputSchema(openApiSchema) + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + } + + @Test + public void functionCallTestInferredOpenApiSchema() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Paris and in Tokyo? 
Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .toolCallbacks(List.of( + FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .description("Get the current weather in a given location.") + .inputType(MockWeatherService.Request.class) + .build(), + FunctionToolCallback.builder("get_payment_status", new PaymentStatus()) + .description( + "Retrieves the payment status for transaction. For example what is the payment status for transaction 700?") + .inputType(PaymentInfoRequest.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + + ChatResponse response2 = this.chatModel + .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); + + logger.info("Response: {}", response2); + + assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("transaction 696 is PAYED"); + + } + + @Test + public void functionCallTestInferredOpenApiSchemaStream() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the current weather in a given location") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + Flux response = this.chatModel.stream(new Prompt(messages, promptOptions)); + + String responseString = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + + logger.info("Response: {}", responseString); + + assertThat(responseString).contains("30", "10", "15"); + + } + + public record PaymentInfoRequest(String id) { + + } + + public record TransactionStatus(String status) { + + } + + public static class PaymentStatus implements Function { + + @Override + public TransactionStatus apply(PaymentInfoRequest paymentInfoRequest) { + return new TransactionStatus("Transaction " + paymentInfoRequest.id() + " is PAYED"); + } + + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public VertexAI vertexAiApi() { + String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); + String location = System.getenv("VERTEX_AI_GEMINI_LOCATION"); + return new VertexAI.Builder().setLocation(location) + .setProjectId(projectId) + .setTransport(Transport.REST) + .build(); + } + + @Bean + public VertexAiGeminiChatModel vertexAiEmbedding(VertexAI vertexAi) { + return VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .temperature(0.9) + .build()) + .build(); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java 
b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java similarity index 97% rename from models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java rename to models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java index 888f4536662..96ff923bc29 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.vertexai.gemini.function; +package org.springframework.ai.vertexai.gemini.tool; import java.util.List; import java.util.Map; @@ -55,9 +55,10 @@ @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") -public class VertexAiGeminiPaymentTransactionIT { +@Deprecated +public class VertexAiGeminiPaymentTransactionDeprecatedIT { - private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionIT.class); + private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionDeprecatedIT.class); private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java new file mode 100644 index 00000000000..ee0574d7645 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java @@ -0,0 +1,247 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vertexai.gemini.tool; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; +import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Description; +import org.springframework.context.support.GenericApplicationContext; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +public class VertexAiGeminiPaymentTransactionIT { + + private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionIT.class); + + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + + @Autowired + ChatClient chatClient; + + @Test + public void paymentStatuses() { + // @formatter:off + String content = this.chatClient.prompt() + .advisors(new LoggingAdvisor()) + .tools("paymentStatus") + .user(""" + What is the status of my payment transactions 001, 002 and 003? + If required invoke the function per transaction. + """).call().content(); + // @formatter:on + logger.info("" + content); + + assertThat(content).contains("001", "002", "003"); + assertThat(content).contains("pending", "approved", "rejected"); + } + + @RepeatedTest(5) + public void streamingPaymentStatuses() { + + Flux streamContent = this.chatClient.prompt() + .advisors(new LoggingAdvisor()) + .tools("paymentStatus") + .user(""" + What is the status of my payment transactions 001, 002 and 003? + If required invoke the function per transaction. 
+ """) + .stream() + .content(); + + String content = streamContent.collectList().block().stream().collect(Collectors.joining()); + + logger.info(content); + + assertThat(content).contains("001", "002", "003"); + assertThat(content).contains("pending", "approved", "rejected"); + + // Quota rate + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + } + } + + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var response = chain.nextAroundCall(before(advisedRequest)); + observeAfter(response); + return response; + } + + private AdvisedRequest before(AdvisedRequest request) { + logger.info("System text: \n" + request.systemText()); + logger.info("System params: " + request.systemParams()); + logger.info("User text: \n" + request.userText()); + logger.info("User params:" + request.userParams()); + logger.info("Function names: " + request.functionNames()); + + logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); + } + + } + + record Transaction(String id) { + } + + record Status(String name) { + } + + record Transactions(List transactions) { + } + + record Statuses(List statuses) { + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + @Description("Get the status of a single payment transaction") + public Function paymentStatus() { + return transaction -> { + logger.info("Single Transaction: " + transaction); + return DATASET.get(transaction); + }; + } + + @Bean + @Description("Get the list statuses of a list of payment transactions") + public Function paymentStatuses() { + return transactions -> { + logger.info("Transactions: " + transactions); + return new Statuses(transactions.transactions().stream().map(t -> DATASET.get(t)).toList()); + }; + } + + @Bean + public ChatClient chatClient(VertexAiGeminiChatModel chatModel) { + return ChatClient.builder(chatModel).build(); + } + + @Bean + public VertexAI vertexAiApi() { + + String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); + String location = System.getenv("VERTEX_AI_GEMINI_LOCATION"); + + return new VertexAI.Builder().setLocation(location) + .setProjectId(projectId) + .setTransport(Transport.REST) + // .setTransport(Transport.GRPC) + .build(); + } + + @Bean + public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingManager toolCallingManager) { + + return VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .toolCallingManager(toolCallingManager) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .temperature(0.1) + .build()) + .build(); + } + + @Bean + ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, + List toolCallbacks, ObjectProvider observationRegistry) { + + var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks); + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() + .applicationContext(applicationContext) + .build(); + + 
ToolCallbackResolver toolCallbackResolver = new DelegatingToolCallbackResolver( + List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); + + return ToolCallingManager.builder() + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .toolCallbackResolver(toolCallbackResolver) + .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) + .build(); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java new file mode 100644 index 00000000000..85c9c5218f8 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java @@ -0,0 +1,248 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.gemini.tool; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import com.google.cloud.vertexai.Transport; +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; +import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; +import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import 
org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.support.GenericApplicationContext; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") +public class VertexAiGeminiPaymentTransactionMethodIT { + + private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionMethodIT.class); + + private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), + new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); + + @Autowired + ChatClient chatClient; + + @Test + public void paymentStatuses() { + + String content = this.chatClient.prompt().advisors(new LoggingAdvisor()).tools("paymentStatus").user(""" + What is the status of my payment transactions 001, 002 and 003? + If required invoke the function per transaction. + """).call().content(); + logger.info("" + content); + + assertThat(content).contains("001", "002", "003"); + assertThat(content).contains("pending", "approved", "rejected"); + } + + @RepeatedTest(5) + public void streamingPaymentStatuses() { + + Flux streamContent = this.chatClient.prompt() + .advisors(new LoggingAdvisor()) + .tools("paymentStatus") + .user(""" + What is the status of my payment transactions 001, 002 and 003? + If required invoke the function per transaction. + """) + .stream() + .content(); + + String content = streamContent.collectList().block().stream().collect(Collectors.joining()); + + logger.info(content); + + assertThat(content).contains("001", "002", "003"); + assertThat(content).contains("pending", "approved", "rejected"); + + // Quota rate + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + } + } + + record TransactionStatusResponse(String id, String status) { + + } + + private static class LoggingAdvisor implements CallAroundAdvisor { + + private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var response = chain.nextAroundCall(before(advisedRequest)); + observeAfter(response); + return response; + } + + private AdvisedRequest before(AdvisedRequest request) { + logger.info("System text: \n" + request.systemText()); + logger.info("System params: " + request.systemParams()); + logger.info("User text: \n" + request.userText()); + logger.info("User params:" + request.userParams()); + logger.info("Function names: " + request.functionNames()); + + logger.info("Options: " + request.chatOptions().toString()); + + return request; + } + + private void observeAfter(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); + } + + } + + record Transaction(String id) { + } + + record Status(String name) { + } + + public static class PaymentService { + + @Tool(description = "Get the status of a single payment transaction") + public Status paymentStatus(Transaction transaction) { + logger.info("Single Transaction: " + transaction); + return DATASET.get(transaction); + } + + @Tool(description = "Get the list statuses of a list of 
payment transactions") + public List statusespaymentStatuses(List transactions) { + logger.info("Transactions: " + transactions); + return transactions.stream().map(t -> DATASET.get(t)).toList(); + } + + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public List paymentServiceTools() { + var tools = List.of(ToolCallbacks.from(new PaymentService())); + return tools; + } + + @Bean + public ChatClient chatClient(VertexAiGeminiChatModel chatModel) { + return ChatClient.builder(chatModel).build(); + } + + @Bean + public VertexAI vertexAiApi() { + + String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); + String location = System.getenv("VERTEX_AI_GEMINI_LOCATION"); + + return new VertexAI.Builder().setLocation(location) + .setProjectId(projectId) + .setTransport(Transport.REST) + // .setTransport(Transport.GRPC) + .build(); + } + + @Bean + public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ToolCallingManager toolCallingManager) { + + return VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .toolCallingManager(toolCallingManager) + .defaultOptions(VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .temperature(0.1) + .build()) + .build(); + } + + @Bean + ToolCallingManager toolCallingManager(GenericApplicationContext applicationContext, + List toolCallbacks, List functionCallbacks, + ObjectProvider observationRegistry) { + + List allFunctionCallbacks = new ArrayList(functionCallbacks); + allFunctionCallbacks.addAll(toolCallbacks.stream().map(tc -> (FunctionCallback) tc).toList()); + + var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionCallbacks); + + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() + .applicationContext(applicationContext) + .build(); + + ToolCallbackResolver toolCallbackResolver = new DelegatingToolCallbackResolver( + List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); + + return ToolCallingManager.builder() + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .toolCallbackResolver(toolCallbackResolver) + .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) + .build(); + } + + } + +} diff --git a/pom.xml b/pom.xml index f47ff2f5786..58b192f46cc 100644 --- a/pom.xml +++ b/pom.xml @@ -190,7 +190,7 @@ 4.3.4 1.0.0-beta.13 1.1.0 - 4.31.1 + 4.37.0 1.9.25 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java index 3d2599874db..504fdd92f60 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java @@ -244,7 +244,7 @@ private static String getMethodParameterDescription(Method method, int index) { } // Based on the method in ModelOptionsUtils. 
- private static void convertTypeValuesToUpperCase(ObjectNode node) { + public static void convertTypeValuesToUpperCase(ObjectNode node) { if (node.isObject()) { node.fields().forEachRemaining(entry -> { JsonNode value = entry.getValue(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index f94c7ea0843..434895a81a8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -14,7 +14,6 @@ *** xref:api/chat/deepseek-chat.adoc[DeepSeek AI] *** xref:api/chat/google-vertexai.adoc[Google VertexAI] **** xref:api/chat/vertexai-gemini-chat.adoc[VertexAI Gemini] -***** xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Gemini Function Calling] *** xref:api/chat/groq-chat.adoc[Groq] *** xref:api/chat/huggingface.adoc[Hugging Face] *** xref:api/chat/mistralai-chat.adoc[Mistral AI] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc index cd3c7f76be9..48b7048ae2c 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock-converse.adoc @@ -115,7 +115,7 @@ String response = ChatClient.create(this.chatModel) .content(); ---- -== Tool/Function Calling +== Tool Calling The Bedrock Converse API supports tool calling capabilities, allowing models to use tools during conversations. Here's an example of how to define and use @Tool based tools: diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc deleted file mode 100644 index de5d31de3ed..00000000000 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/vertexai-gemini-chat-functions.adoc +++ /dev/null @@ -1,207 +0,0 @@ -= Gemini Function Calling - -WARNING: -Apparently the Gemini Pro can not handle anymore the function name correctly. -The parallel function calling is gone as well. - -Function calling lets developers create a description of a function in their code, then pass that description to a language model in a request. The response from the model includes the name of a function that matches the description and the arguments to call it with. - -You can register custom Java functions with the `VertexAiGeminiChatModel` and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. -This allows you to connect the LLM capabilities with external tools and APIs. -The VertexAI Gemini Pro model is trained to detect when a function should be called and to respond with JSON that adheres to the function signature. - -The VertexAI Gemini API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. - -Spring AI provides flexible and user-friendly ways to register and call custom functions. -In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as Open API schema) to let the model know what arguments the function expects. The `description` helps the model to understand when to call the function. 
- -As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and responds with the result back to the model. -Your function can in turn invoke other 3rd party services to provide the results. - -Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. - -Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. -The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion Builder utility class to simplify the implementation and registration of Java callback functions. - -// Additionally, the Auto-Configuration provides a way to auto-register any Function beans definition as function calling candidates in the `ChatModel`. - -== How it works - -Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. - -We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. - -For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. -The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. - -Spring AI greatly simplifies the code you need to write to support function invocation. -It brokers the function invocation conversation for you. -You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. -You can also reference multiple function bean names in your prompt. - -== Quick Start - -Let's create a chatbot that answer questions by calling our own function. -To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. - -When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. - -Our function can some SaaS based weather service API and returns the weather response back to the model to complete the conversation. In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. 
- -The following `MockWeatherService.java` represents the weather service API: - -[source,java] ----- -public class MockWeatherService implements Function { - - public enum Unit { C, F } - public record Request(String location, Unit unit) {} - public record Response(double temp, Unit unit) {} - - public Response apply(Request request) { - return new Response(30.0, Unit.C); - } -} ----- - -=== Registering Functions as Beans - -With the link:../vertexai-gemini-chat.html#_auto_configuration[VertexAiGeminiChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. - -We start with describing the most POJO friendly options. - -==== Plain Java Functions - -In this approach you define `@Beans` in your application context as you would any other Spring managed object. - -Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallback` instance that adds the logic for it being invoked via the AI model. -The name of the `@Bean` is passed as a `ChatOption`. - - -[source,java] ----- -@Configuration -static class Config { - - @Bean - @Description("Get the weather in location") // function description - public Function weatherFunction1() { - return new MockWeatherService(); - } - ... -} ----- - -The `@Description` annotation is optional and provides a function description (2) that helps the model understand when to call the function. It is an important property to set to help the AI model determine what client side function to invoke. - -Another option to provide the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` to provide the function description: - -[source,java] ----- - -@Configuration -static class Config { - - @Bean - public Function currentWeather3() { // (1) bean name as function name. - return new MockWeatherService(); - } - ... -} - -@JsonClassDescription("Get the weather in location") // (2) function description -public record Request(String location, Unit unit) {} ----- - -It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. - -The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/gemini/tool/FunctionCallWithFunctionBeanIT.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach. - -==== FunctionCallback Wrapper - -Another way to register a function is to create a `FunctionCallback` instance like this: - -[source,java] ----- -@Configuration -static class Config { - - @Bean - public FunctionCallback weatherFunctionInfo() { - - return FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) // (1) function name and instance - .description("Get the current weather in a given location") // (2) function description - .schemaType(SchemaType.OPEN_API_SCHEMA) // (3) schema type. Compulsory for Gemini function calling. - .inputType(MockWeatherService.Request.class) // (4) input type - .build(); - } - ... -} ----- - -It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `VertexAiGeminiChatModel`. -It also provides a description (2), the Schema type to Open API type (3) and input type (4) used to generate the Open API schema for the function call. 
- -NOTE: The default response converter performs a JSON serialization of the `Response` object. - -NOTE: The `FunctionCallback` internally resolves the function call signature based on the `MockWeatherService.Request` class and generates an Open API schema for the function call. - -=== Specifying functions in Chat Options - -To let the model know about and call your `CurrentWeather` function, you need to enable it in your prompt requests: - -[source,java] ----- -VertexAiGeminiChatModel chatModel = ... - -UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - -ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), - VertexAiGeminiChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function - -logger.info("Response: {}", response); ----- - -// NOTE: You can have multiple functions registered in your `ChatModel`, but only those enabled in the prompt request will be considered for function calling. - -The above user question will trigger 3 calls to the `CurrentWeather` function (one for each city), and the final response will be something like this: - ----- -Here is the current weather for the requested cities: -- San Francisco, CA: 30.0°C -- Tokyo, Japan: 10.0°C -- Paris, France: 15.0°C ----- - -The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/gemini/tool/FunctionCallWithFunctionWrapperIT.java[FunctionCallWithFunctionWrapperIT.java] test demonstrates this approach. - - -=== Register/Call Functions with Prompt Options - -In addition to the auto-configuration, you can register callback functions dynamically with your Prompt requests: - -[source,java] ----- -VertexAiGeminiChatModel chatModel = ... - -UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris? Use Multi-turn function calling."); - -var promptOptions = VertexAiGeminiChatOptions.builder() - .withFunctionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeather", new MockWeatherService()) - .schemaType(SchemaType.OPEN_API_SCHEMA) // IMPORTANT!! - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) - .build(); - -ChatResponse response = this.chatModel.call(new Prompt(List.of(this.userMessage), this.promptOptions)); ----- - -NOTE: Functions registered in the prompt are enabled by default for the duration of that request. - -This approach allows you to dynamically choose different functions to be called based on the user input. - -The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/gemini/tool/FunctionCallWithPromptFunctionIT.java[FunctionCallWithPromptFunctionIT.java] integration test provides a complete example of how to register a function with the `VertexAiGeminiChatModel` and use it in a prompt request.
- diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index f88867f4064..dc9ce83b7ed 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -113,11 +113,49 @@ ChatResponse response = chatModel.call( TIP: In addition to the model specific `VertexAiGeminiChatOptions` you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. -== Function Calling +== Tool Calling + +Vertex AI Gemini supports tool calling, allowing models to use tools during conversations. +Here's an example of how to define and use `@Tool`-based tools: + +[source,java] +---- + +public class WeatherService { + + @Tool(description = "Get the weather in location") + public String weatherByLocation(@ToolParam(description = "City or state name") String location) { + ... + } +} + +String response = ChatClient.create(this.chatModel) + .prompt("What's the weather like in Boston?") + .tools(new WeatherService()) + .call() + .content(); +---- + +You can also use `java.util.function` beans as tools: + +[source,java] +---- +@Bean +@Description("Get the weather in location. Return temperature in 36°F or 36°C format.") +public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction() { + return new MockWeatherService(); +} + +String response = ChatClient.create(this.chatModel) + .prompt("What's the weather like in Boston?") + .tools("weatherFunction") + .inputType(Request.class) + .call() + .content(); +---- + +Find more details in the xref:api/tools.adoc[Tools] documentation. -You can register custom Java functions with the VertexAiGeminiChatModel and have the Gemini Pro model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. -This is a powerful technique to connect the LLM capabilities with external tools and APIs. -Read more about xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Vertex AI Gemini Function Calling]. == Multimodal diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index e64d66d86e9..00e450f3267 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -5,16 +5,17 @@ WARNING: This page describes the previous version of the Function Calling API, w The integration of function support in AI models, permits the model to request the execution of client-side functions, thereby accessing necessary information or performing tasks dynamically as required. -Spring AI currently supports function invocation for the following AI Models: - -* Anthropic Claude: Refer to the xref:api/chat/functions/anthropic-chat-functions.adoc[Anthropic Claude function invocation docs]. -* Azure OpenAI: Refer to the xref:api/chat/functions/azure-open-ai-chat-functions.adoc[Azure OpenAI function invocation docs].
-* Google VertexAI Gemini: Refer to the xref:api/chat/functions/vertexai-gemini-chat-functions.adoc[Vertex AI Gemini function invocation docs]. -* Groq: Refer to the xref:api/chat/groq-chat.adoc#_function_calling[Groq function invocation docs]. -* Mistral AI: Refer to the xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral AI function invocation docs]. +Spring AI currently supports tool/function calling for the following AI Models: + +* Anthropic Claude: Refer to the xref:api/chat/functions/anthropic-chat-functions.adoc[Anthropic Claude tool/function calling]. +* Azure OpenAI: Refer to the xref:api/chat/functions/azure-open-ai-chat-functions.adoc[Azure OpenAI tool/function calling]. +* Amazon Bedrock Converse: Refer to the xref:api/chat/bedrock-converse.adoc#_tool_calling[Amazon Bedrock Converse tool/function calling]. +* Google VertexAI Gemini: Refer to the xref:api/chat/vertexai-gemini-chat.adoc#_tool_calling[Vertex AI Gemini tool/function calling]. +* Groq: Refer to the xref:api/chat/groq-chat.adoc#_function_calling[Groq tool/function calling]. +* Mistral AI: Refer to the xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral AI tool/function calling]. // * MiniMax : Refer to the xref:api/chat/functions/minimax-chat-functions.adoc[MiniMax function invocation docs]. -* Ollama: Refer to the xref:api/chat/functions/ollama-chat-functions.adoc[Ollama function invocation docs] (streaming not supported yet). -* OpenAI: Refer to the xref:api/chat/functions/openai-chat-functions.adoc[OpenAI function invocation docs]. +* Ollama: Refer to the xref:api/chat/functions/ollama-chat-functions.adoc[Ollama tool/function calling]. +* OpenAI: Refer to the xref:api/chat/functions/openai-chat-functions.adoc[OpenAI tool/function calling]. // * ZhiPu AI : Refer to the xref:api/chat/functions/zhipuai-chat-functions.adoc[ZhiPu AI function invocation docs].
image::function-calling-basic-flow.jpg[Function calling, width=700, align="center"] diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java index 2f1125a31c1..69b4321d8d9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiAutoConfiguration.java @@ -23,12 +23,14 @@ import com.google.cloud.vertexai.VertexAI; import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallback.SchemaType; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -52,10 +54,10 @@ * @author Mark Pollack * @since 1.0.0 */ -@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class }) +@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class }) @ConditionalOnClass({ VertexAI.class, VertexAiGeminiChatModel.class }) @EnableConfigurationProperties({ VertexAiGeminiChatProperties.class, VertexAiGeminiConnectionProperties.class }) -@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class }) +@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class }) public class VertexAiGeminiAutoConfiguration { @Bean @@ -91,16 +93,20 @@ public VertexAI vertexAi(VertexAiGeminiConnectionProperties connectionProperties @ConditionalOnProperty(prefix = VertexAiGeminiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGeminiChatProperties chatProperties, - List toolFunctionCallbacks, ApplicationContext context, RetryTemplate retryTemplate, + ToolCallingManager toolCallingManager, ApplicationContext context, RetryTemplate retryTemplate, ObjectProvider observationRegistry, ObjectProvider observationConvention) { - FunctionCallbackResolver functionCallbackResolver = springAiFunctionManager(context); + VertexAiGeminiChatModel chatModel = VertexAiGeminiChatModel.builder() + .vertexAI(vertexAi) + .defaultOptions(chatProperties.getOptions()) + .toolCallingManager(toolCallingManager) + .retryTemplate(retryTemplate) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .build(); - VertexAiGeminiChatModel chatModel = new VertexAiGeminiChatModel(vertexAi, chatProperties.getOptions(), - functionCallbackResolver, toolFunctionCallbacks, retryTemplate, - observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); 
observationConvention.ifAvailable(chatModel::setObservationConvention); + return chatModel; } @@ -108,11 +114,12 @@ public VertexAiGeminiChatModel vertexAiGeminiChat(VertexAI vertexAi, VertexAiGem * Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance must * different from the other JSON schema types. */ - private FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) { - DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); - manager.setSchemaType(SchemaType.OPEN_API_SCHEMA); - manager.setApplicationContext(context); - return manager; - } + // private FunctionCallbackResolver springAiFunctionManager(ApplicationContext + // context) { + // DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); + // manager.setSchemaType(SchemaType.OPEN_API_SCHEMA); + // manager.setApplicationContext(context); + // return manager; + // } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java index f024067bc1f..9ce224a7a8c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/gemini/VertexAiGeminiChatProperties.java @@ -31,7 +31,7 @@ public class VertexAiGeminiChatProperties { public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.gemini.chat"; - public static final String DEFAULT_MODEL = VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_PRO.getValue(); + public static final String DEFAULT_MODEL = VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue(); /** * Vertex AI Gemini API generative options. 
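With the auto-configuration now delegating to the builder, a manual configuration can follow the same shape. The sketch below is illustrative only and assumes that `vertexAI`, `toolCallingManager`, and `retryTemplate` instances are already available; the `model(...)` option setter is an assumption, while the remaining builder methods mirror those used in the updated `vertexAiGeminiChat` bean above.

[source,java]
----
// Illustrative sketch of manual construction, mirroring the auto-configured bean above.
// Assumes vertexAI, toolCallingManager and retryTemplate are already available.
VertexAiGeminiChatModel chatModel = VertexAiGeminiChatModel.builder()
    .vertexAI(vertexAI)
    .defaultOptions(VertexAiGeminiChatOptions.builder()
        .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) // assumed option setter
        .build())
    .toolCallingManager(toolCallingManager)
    .retryTemplate(retryTemplate)
    .observationRegistry(ObservationRegistry.NOOP)
    .build();
----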
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index 92769b23488..eb31e3b8840 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -57,7 +57,7 @@ void functionCallTest() { this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) - + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) + + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); @@ -97,7 +97,7 @@ void functionCallWithPortableFunctionCallingOptions() { this.contextRunner.withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" // + VertexAiGeminiChatModel.ChatModel.GEMINI_PRO_1_5_PRO.getValue()) - + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) + + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java index 4747e6af870..bfcdb7c6433 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionWrapperIT.java @@ -27,8 +27,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -54,7 +54,7 @@ public class FunctionCallWithFunctionWrapperIT { void functionCallTest() { this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" - + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) + + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH.getValue()) .run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); @@ -65,7 +65,7 @@ void functionCallTest() { """); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - VertexAiGeminiChatOptions.builder().function("WeatherInfo").build())); + VertexAiGeminiChatOptions.builder().toolName("WeatherInfo").build())); logger.info("Response: 
{}", response); @@ -77,12 +77,10 @@ void functionCallTest() { static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public ToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("WeatherInfo", new MockWeatherService()) + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) .description("Get the current weather in a given location") - .schemaType(SchemaType.OPEN_API_SCHEMA) .inputType(MockWeatherService.Request.class) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java index 99e524f5219..2227f80c623 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithPromptFunctionIT.java @@ -29,6 +29,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -51,7 +52,7 @@ public class FunctionCallWithPromptFunctionIT { void functionCallTest() { this.contextRunner .withPropertyValues("spring.ai.vertex.ai.gemini.chat.options.model=" - + VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue()) + + VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH_LIGHT.getValue()) .run(context -> { VertexAiGeminiChatModel chatModel = context.getBean(VertexAiGeminiChatModel.class); @@ -68,12 +69,11 @@ void functionCallTest() { """); var promptOptions = VertexAiGeminiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .schemaType(SchemaType.OPEN_API_SCHEMA) // IMPORTANT!! - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) + .toolCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));