Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion models/spring-ai-vertex-ai-gemini/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
Expand Down Expand Up @@ -53,6 +54,17 @@

<dependencies>

<dependency>
<groupId>com.github.victools</groupId>
<artifactId>jsonschema-generator</artifactId>
<version>${victools.version}</version>
</dependency>
<dependency>
<groupId>com.github.victools</groupId>
<artifactId>jsonschema-module-jackson</artifactId>
<version>${victools.version}</version>
</dependency>

<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-vertexai</artifactId>
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,11 +29,12 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -46,7 +47,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {

// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig

Expand Down Expand Up @@ -95,40 +96,36 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
private @JsonProperty("responseMimeType") String responseMimeType;

/**
* Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
* from the registry to be used by the ChatModel chat completion requests.
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
*/
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests.
* Functions with those names must exist in the functionCallbacks registry.
* The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
* Collection of tool names to be resolved at runtime and used for tool calling in the
* chat completion requests.
*/
@JsonIgnore
private Set<String> functions = new HashSet<>();
private Set<String> toolNames = new HashSet<>();

/**
* Use Google search Grounding feature
* Whether to enable the tool execution lifecycle internally in ChatModel.
*/
@JsonIgnore
private boolean googleSearchRetrieval = false;
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();
private Map<String, Object> toolContext = new HashMap<>();

/**
* Use Google search Grounding feature
*/
@JsonIgnore
private Boolean proxyToolCalls;
private Boolean googleSearchRetrieval = false;

@JsonIgnore
private Map<String, Object> toolContext;
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();

public static Builder builder() {
return new Builder();
Expand All @@ -145,13 +142,13 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
options.setCandidateCount(fromOptions.getCandidateCount());
options.setMaxOutputTokens(fromOptions.getMaxOutputTokens());
options.setModel(fromOptions.getModel());
options.setFunctionCallbacks(fromOptions.getFunctionCallbacks());
options.setToolCallbacks(fromOptions.getToolCallbacks());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setFunctions(fromOptions.getFunctions());
options.setToolNames(fromOptions.getToolNames());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
options.setSafetySettings(fromOptions.getSafetySettings());
options.setProxyToolCalls(fromOptions.getProxyToolCalls());
options.setInternalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled());
options.setToolContext(fromOptions.getToolContext());
return options;
}
Expand Down Expand Up @@ -236,20 +233,67 @@ public void setResponseMimeType(String mimeType) {
this.responseMimeType = mimeType;
}

@Override
@JsonIgnore
@Deprecated
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.getToolCallbacks();
}

@Override
@JsonIgnore
@Deprecated
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.functionCallbacks = functionCallbacks;
this.setToolCallbacks(functionCallbacks);
}

@Override
public List<FunctionCallback> getToolCallbacks() {
return this.toolCallbacks;
}

@Override
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = toolCallbacks;
}

@Override
@JsonIgnore
@Deprecated
public Set<String> getFunctions() {
return this.functions;
return this.getToolNames();
}

@JsonIgnore
@Deprecated
public void setFunctions(Set<String> functions) {
this.functions = functions;
this.setToolNames(functions);
}

@Override
public Set<String> getToolNames() {
return this.toolNames;
}

@Override
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
this.toolNames = toolNames;
}

@Override
@Nullable
public Boolean isInternalToolExecutionEnabled() {
return internalToolExecutionEnabled;
}

@Override
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
Expand All @@ -264,11 +308,11 @@ public Double getPresencePenalty() {
return null;
}

public boolean getGoogleSearchRetrieval() {
public Boolean getGoogleSearchRetrieval() {
return this.googleSearchRetrieval;
}

public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
public void setGoogleSearchRetrieval(Boolean googleSearchRetrieval) {
this.googleSearchRetrieval = googleSearchRetrieval;
}

Expand All @@ -281,13 +325,17 @@ public void setSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings)
this.safetySettings = safetySettings;
}

@Deprecated
@Override
@JsonIgnore
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
}

@Deprecated
@JsonIgnore
public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
}

@Override
Expand All @@ -314,96 +362,35 @@ public boolean equals(Object o) {
&& Objects.equals(this.topK, that.topK) && Objects.equals(this.candidateCount, that.candidateCount)
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
&& Objects.equals(this.responseMimeType, that.responseMimeType)
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
&& Objects.equals(this.functions, that.functions)
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.safetySettings, that.safetySettings)
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolContext, that.toolContext);
}

@Override
public int hashCode() {
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions,
this.googleSearchRetrieval, this.safetySettings, this.proxyToolCalls, this.toolContext);
this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames,
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext);
}

@Override
public String toString() {
return "VertexAiGeminiChatOptions{" + "stopSequences=" + this.stopSequences + ", temperature="
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", candidateCount="
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks="
+ this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval="
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}';
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
+ ", safetySettings=" + this.safetySettings + '}';
}

@Override
public VertexAiGeminiChatOptions copy() {
return fromOptions(this);
}

public FunctionCallingOptions merge(ChatOptions options) {
VertexAiGeminiChatOptions.Builder builder = VertexAiGeminiChatOptions.builder();

// Merge chat-specific options
builder.model(options.getModel() != null ? options.getModel() : this.getModel())
.maxOutputTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.getMaxOutputTokens())
.stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.getStopSequences())
.temperature(options.getTemperature() != null ? options.getTemperature() : this.getTemperature())
.topP(options.getTopP() != null ? options.getTopP() : this.getTopP())
.topK(options.getTopK() != null ? options.getTopK() : this.getTopK());

// Try to get function-specific properties if options is a FunctionCallingOptions
if (options instanceof FunctionCallingOptions functionOptions) {
builder.proxyToolCalls(functionOptions.getProxyToolCalls() != null ? functionOptions.getProxyToolCalls()
: this.proxyToolCalls);

Set<String> functions = new HashSet<>();
if (this.functions != null) {
functions.addAll(this.functions);
}
if (functionOptions.getFunctions() != null) {
functions.addAll(functionOptions.getFunctions());
}
builder.functions(functions);

List<FunctionCallback> functionCallbacks = new ArrayList<>();
if (this.functionCallbacks != null) {
functionCallbacks.addAll(this.functionCallbacks);
}
if (functionOptions.getFunctionCallbacks() != null) {
functionCallbacks.addAll(functionOptions.getFunctionCallbacks());
}
builder.functionCallbacks(functionCallbacks);

Map<String, Object> context = new HashMap<>();
if (this.toolContext != null) {
context.putAll(this.toolContext);
}
if (functionOptions.getToolContext() != null) {
context.putAll(functionOptions.getToolContext());
}
builder.toolContext(context);
}
else {
// If not a FunctionCallingOptions, preserve current function-specific
// properties
builder.proxyToolCalls(this.proxyToolCalls);
builder.functions(this.functions != null ? new HashSet<>(this.functions) : null);
builder.functionCallbacks(this.functionCallbacks != null ? new ArrayList<>(this.functionCallbacks) : null);
builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null);
}

// Preserve Vertex AI Gemini-specific properties
builder.candidateCount(this.candidateCount)
.responseMimeType(this.responseMimeType)
.googleSearchRetrieval(this.googleSearchRetrieval)
.safetySettings(this.safetySettings != null ? new ArrayList<>(this.safetySettings) : null);

return builder.build();
}

public enum TransportType {

GRPC, REST
Expand Down Expand Up @@ -460,20 +447,35 @@ public Builder responseMimeType(String mimeType) {
return this;
}

@Deprecated
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return toolCallbacks(functionCallbacks);
}

public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.toolCallbacks = toolCallbacks;
return this;
}

@Deprecated
public Builder functions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this.toolNames(functionNames);
}

public Builder toolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "Function names must not be null");
this.options.toolNames = toolNames;
return this;
}

@Deprecated
public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this.toolName(functionName);
}

public Builder toolName(String toolName) {
Assert.hasText(toolName, "Function name must not be empty");
this.options.toolNames.add(toolName);
return this;
}

Expand All @@ -488,8 +490,13 @@ public Builder safetySettings(List<VertexAiGeminiSafetySetting> safetySettings)
return this;
}

@Deprecated
public Builder proxyToolCalls(boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
return this.internalToolExecutionEnabled(proxyToolCalls);
}

public Builder internalToolExecutionEnabled(boolean internalToolExecutionEnabled) {
this.options.internalToolExecutionEnabled = internalToolExecutionEnabled;
return this;
}

Expand Down
Loading