diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 30f030c3737..b6bd69c8c01 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -36,6 +36,7 @@ import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.GoogleSearchRetrieval; import com.google.cloud.vertexai.api.Part; +import com.google.cloud.vertexai.api.SafetySetting; import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.generativeai.GenerativeModel; @@ -46,6 +47,7 @@ import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -453,7 +455,8 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat GenerationConfig generationConfig = this.generationConfig; var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel()) - .setVertexAi(this.vertexAI); + .setVertexAi(this.vertexAI) + .setSafetySettings(toGeminiSafetySettings(this.defaultOptions.getSafetySettings())); if (prompt.getOptions() != null) { if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { @@ -499,6 +502,11 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat generativeModelBuilder.setTools(tools); } + if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options + && !CollectionUtils.isEmpty(options.getSafetySettings())) { + generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(options.getSafetySettings())); + } + generativeModelBuilder.setGenerationConfig(generationConfig); GenerativeModel generativeModel = generativeModelBuilder.build(); @@ -557,6 +565,16 @@ private List toGeminiContent(List instrucitons) { return contents; } + private List toGeminiSafetySettings(List safetySettings) { + return safetySettings.stream() + .map(safetySetting -> SafetySetting.newBuilder() + .setCategoryValue(safetySetting.getCategory().getValue()) + .setThresholdValue(safetySetting.getThreshold().getValue()) + .setMethodValue(safetySetting.getMethod().getValue()) + .build()) + .toList(); + } + private List getFunctionTools(Set functionNames) { final var tool = Tool.newBuilder(); diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index bdd81c3e45d..de5be2151eb 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -31,6 +31,7 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel; +import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import org.springframework.util.Assert; /** @@ -117,6 +118,9 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions { @JsonIgnore private boolean googleSearchRetrieval = false; + @JsonIgnore + private List safetySettings = new ArrayList<>(); + @JsonIgnore private Boolean proxyToolCalls; @@ -143,6 +147,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setFunctions(fromOptions.getFunctions()); options.setResponseMimeType(fromOptions.getResponseMimeType()); options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); + options.setSafetySettings(fromOptions.getSafetySettings()); options.setProxyToolCalls(fromOptions.getProxyToolCalls()); options.setToolContext(fromOptions.getToolContext()); return options; @@ -269,6 +274,15 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) { this.googleSearchRetrieval = googleSearchRetrieval; } + public List getSafetySettings() { + return safetySettings; + } + + public void setSafetySettings(List safetySettings) { + Assert.notNull(safetySettings, "safetySettings must not be null"); + this.safetySettings = safetySettings; + } + @Override public Boolean getProxyToolCalls() { return this.proxyToolCalls; @@ -304,6 +318,7 @@ public boolean equals(Object o) { && Objects.equals(this.responseMimeType, that.responseMimeType) && Objects.equals(this.functionCallbacks, that.functionCallbacks) && Objects.equals(this.functions, that.functions) + && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) && Objects.equals(this.toolContext, that.toolContext); } @@ -312,7 +327,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions, - this.googleSearchRetrieval, this.proxyToolCalls, this.toolContext); + this.googleSearchRetrieval, this.safetySettings, this.proxyToolCalls, this.toolContext); } @Override @@ -322,7 +337,7 @@ public String toString() { + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks=" + this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval=" - + this.googleSearchRetrieval + '}'; + + this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}'; } @Override @@ -408,6 +423,12 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) { return this; } + public Builder withSafetySettings(List safetySettings) { + Assert.notNull(safetySettings, "safetySettings must not be null"); + this.options.safetySettings = safetySettings; + return this; + } + public Builder withProxyToolCalls(boolean proxyToolCalls) { this.options.proxyToolCalls = proxyToolCalls; return this; diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiSafetySetting.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiSafetySetting.java new file mode 100644 index 00000000000..9e120cea2ad --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/common/VertexAiGeminiSafetySetting.java @@ -0,0 +1,167 @@ +package org.springframework.ai.vertexai.gemini.common; + +public class VertexAiGeminiSafetySetting { + + /** + * Enum representing different threshold levels for blocking harmful content. + */ + public enum HarmBlockThreshold { + + HARM_BLOCK_THRESHOLD_UNSPECIFIED(0), BLOCK_LOW_AND_ABOVE(1), BLOCK_MEDIUM_AND_ABOVE(2), BLOCK_ONLY_HIGH(3), + BLOCK_NONE(4), OFF(5); + + private final int value; + + HarmBlockThreshold(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + } + + /** + * Enum representing methods for evaluating harmful content. + */ + public enum HarmBlockMethod { + + HARM_BLOCK_METHOD_UNSPECIFIED(0), SEVERITY(1), PROBABILITY(2); + + private final int value; + + HarmBlockMethod(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + } + + /** + * Enum representing different categories of harmful content. + */ + public enum HarmCategory { + + HARM_CATEGORY_UNSPECIFIED(0), HARM_CATEGORY_HATE_SPEECH(1), HARM_CATEGORY_DANGEROUS_CONTENT(2), + HARM_CATEGORY_HARASSMENT(3), HARM_CATEGORY_SEXUALLY_EXPLICIT(4); + + private final int value; + + HarmCategory(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + } + + private HarmCategory category; + + private HarmBlockThreshold threshold; + + private HarmBlockMethod method; + + // Default constructor + public VertexAiGeminiSafetySetting() { + this.category = HarmCategory.HARM_CATEGORY_UNSPECIFIED; + this.threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED; + this.method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED; + } + + // Constructor with all fields + public VertexAiGeminiSafetySetting(HarmCategory category, HarmBlockThreshold threshold, HarmBlockMethod method) { + this.category = category; + this.threshold = threshold; + this.method = method; + } + + // Getters and setters + public HarmCategory getCategory() { + return category; + } + + public void setCategory(HarmCategory category) { + this.category = category; + } + + public HarmBlockThreshold getThreshold() { + return threshold; + } + + public void setThreshold(HarmBlockThreshold threshold) { + this.threshold = threshold; + } + + public HarmBlockMethod getMethod() { + return method; + } + + public void setMethod(HarmBlockMethod method) { + this.method = method; + } + + @Override + public String toString() { + return "SafetySetting{" + "category=" + category + ", threshold=" + threshold + ", method=" + method + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + VertexAiGeminiSafetySetting that = (VertexAiGeminiSafetySetting) o; + + if (category != that.category) + return false; + if (threshold != that.threshold) + return false; + return method == that.method; + } + + @Override + public int hashCode() { + int result = category != null ? category.hashCode() : 0; + result = 31 * result + (threshold != null ? threshold.hashCode() : 0); + result = 31 * result + (method != null ? method.hashCode() : 0); + return result; + } + + public static class Builder { + + private HarmCategory category = HarmCategory.HARM_CATEGORY_UNSPECIFIED; + + private HarmBlockThreshold threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED; + + private HarmBlockMethod method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED; + + public Builder withCategory(HarmCategory category) { + this.category = category; + return this; + } + + public Builder withThreshold(HarmBlockThreshold threshold) { + this.threshold = threshold; + return this; + } + + public Builder withMethod(HarmBlockMethod method) { + this.method = method; + return this; + } + + public VertexAiGeminiSafetySetting build() { + return new VertexAiGeminiSafetySetting(category, threshold, method); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java index 4d6a064448c..13fb4e1e3d6 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java @@ -41,6 +41,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.Media; +import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -91,6 +92,18 @@ void googleSearchTool() { assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew"); } + @Test + void testSafetySettings() { + List safetySettings = List.of(new VertexAiGeminiSafetySetting.Builder() + .withCategory(VertexAiGeminiSafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) + .withThreshold(VertexAiGeminiSafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) + .build()); + Prompt prompt = new Prompt("What are common digital attack vectors?", + VertexAiGeminiChatOptions.builder().withSafetySettings(safetySettings).build()); + ChatResponse response = this.chatModel.call(prompt); + assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("SAFETY"); + } + @NotNull private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) { String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc index 58ac4280124..dbc1ece15db 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc @@ -84,6 +84,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo | spring.ai.vertex.ai.gemini.chat.options.presencePenalty | | - | spring.ai.vertex.ai.gemini.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - | spring.ai.vertex.ai.gemini.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false +| spring.ai.vertex.ai.gemini.chat.options.safetySettings | List of safety settings to control safety filters, as defined by https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters[Vertex AI Safety Filters]. Each safety setting can have a method, threshold, and category. | - |====