Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
Expand All @@ -46,6 +47,7 @@
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -453,7 +455,8 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
GenerationConfig generationConfig = this.generationConfig;

var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel())
.setVertexAi(this.vertexAI);
.setVertexAi(this.vertexAI)
.setSafetySettings(toGeminiSafetySettings(this.defaultOptions.getSafetySettings()));

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
Expand Down Expand Up @@ -499,6 +502,11 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
generativeModelBuilder.setTools(tools);
}

if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options
&& !CollectionUtils.isEmpty(options.getSafetySettings())) {
generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(options.getSafetySettings()));
}

generativeModelBuilder.setGenerationConfig(generationConfig);

GenerativeModel generativeModel = generativeModelBuilder.build();
Expand Down Expand Up @@ -557,6 +565,16 @@ private List<Content> toGeminiContent(List<Message> instrucitons) {
return contents;
}

private List<SafetySetting> toGeminiSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
return safetySettings.stream()
.map(safetySetting -> SafetySetting.newBuilder()
.setCategoryValue(safetySetting.getCategory().getValue())
.setThresholdValue(safetySetting.getThreshold().getValue())
.setMethodValue(safetySetting.getMethod().getValue())
.build())
.toList();
}

private List<Tool> getFunctionTools(Set<String> functionNames) {

final var tool = Tool.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
import org.springframework.util.Assert;

/**
Expand Down Expand Up @@ -117,6 +118,9 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
@JsonIgnore
private boolean googleSearchRetrieval = false;

@JsonIgnore
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();

@JsonIgnore
private Boolean proxyToolCalls;

Expand All @@ -143,6 +147,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
options.setFunctions(fromOptions.getFunctions());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
options.setSafetySettings(fromOptions.getSafetySettings());
options.setProxyToolCalls(fromOptions.getProxyToolCalls());
options.setToolContext(fromOptions.getToolContext());
return options;
Expand Down Expand Up @@ -269,6 +274,15 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
this.googleSearchRetrieval = googleSearchRetrieval;
}

public List<VertexAiGeminiSafetySetting> getSafetySettings() {
return safetySettings;
}

public void setSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
Assert.notNull(safetySettings, "safetySettings must not be null");
this.safetySettings = safetySettings;
}

@Override
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
Expand Down Expand Up @@ -304,6 +318,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.responseMimeType, that.responseMimeType)
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
&& Objects.equals(this.functions, that.functions)
&& Objects.equals(this.safetySettings, that.safetySettings)
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
&& Objects.equals(this.toolContext, that.toolContext);
}
Expand All @@ -312,7 +327,7 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
this.maxOutputTokens, this.model, this.responseMimeType, this.functionCallbacks, this.functions,
this.googleSearchRetrieval, this.proxyToolCalls, this.toolContext);
this.googleSearchRetrieval, this.safetySettings, this.proxyToolCalls, this.toolContext);
}

@Override
Expand All @@ -322,7 +337,7 @@ public String toString() {
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", functionCallbacks="
+ this.functionCallbacks + ", functions=" + this.functions + ", googleSearchRetrieval="
+ this.googleSearchRetrieval + '}';
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + '}';
}

@Override
Expand Down Expand Up @@ -408,6 +423,12 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) {
return this;
}

public Builder withSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
Assert.notNull(safetySettings, "safetySettings must not be null");
this.options.safetySettings = safetySettings;
return this;
}

public Builder withProxyToolCalls(boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package org.springframework.ai.vertexai.gemini.common;

public class VertexAiGeminiSafetySetting {

/**
* Enum representing different threshold levels for blocking harmful content.
*/
public enum HarmBlockThreshold {

HARM_BLOCK_THRESHOLD_UNSPECIFIED(0), BLOCK_LOW_AND_ABOVE(1), BLOCK_MEDIUM_AND_ABOVE(2), BLOCK_ONLY_HIGH(3),
BLOCK_NONE(4), OFF(5);

private final int value;

HarmBlockThreshold(int value) {
this.value = value;
}

public int getValue() {
return value;
}

}

/**
* Enum representing methods for evaluating harmful content.
*/
public enum HarmBlockMethod {

HARM_BLOCK_METHOD_UNSPECIFIED(0), SEVERITY(1), PROBABILITY(2);

private final int value;

HarmBlockMethod(int value) {
this.value = value;
}

public int getValue() {
return value;
}

}

/**
* Enum representing different categories of harmful content.
*/
public enum HarmCategory {

HARM_CATEGORY_UNSPECIFIED(0), HARM_CATEGORY_HATE_SPEECH(1), HARM_CATEGORY_DANGEROUS_CONTENT(2),
HARM_CATEGORY_HARASSMENT(3), HARM_CATEGORY_SEXUALLY_EXPLICIT(4);

private final int value;

HarmCategory(int value) {
this.value = value;
}

public int getValue() {
return value;
}

}

private HarmCategory category;

private HarmBlockThreshold threshold;

private HarmBlockMethod method;

// Default constructor
public VertexAiGeminiSafetySetting() {
this.category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;
this.threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;
this.method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;
}

// Constructor with all fields
public VertexAiGeminiSafetySetting(HarmCategory category, HarmBlockThreshold threshold, HarmBlockMethod method) {
this.category = category;
this.threshold = threshold;
this.method = method;
}

// Getters and setters
public HarmCategory getCategory() {
return category;
}

public void setCategory(HarmCategory category) {
this.category = category;
}

public HarmBlockThreshold getThreshold() {
return threshold;
}

public void setThreshold(HarmBlockThreshold threshold) {
this.threshold = threshold;
}

public HarmBlockMethod getMethod() {
return method;
}

public void setMethod(HarmBlockMethod method) {
this.method = method;
}

@Override
public String toString() {
return "SafetySetting{" + "category=" + category + ", threshold=" + threshold + ", method=" + method + '}';
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;

VertexAiGeminiSafetySetting that = (VertexAiGeminiSafetySetting) o;

if (category != that.category)
return false;
if (threshold != that.threshold)
return false;
return method == that.method;
}

@Override
public int hashCode() {
int result = category != null ? category.hashCode() : 0;
result = 31 * result + (threshold != null ? threshold.hashCode() : 0);
result = 31 * result + (method != null ? method.hashCode() : 0);
return result;
}

public static class Builder {

private HarmCategory category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;

private HarmBlockThreshold threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;

private HarmBlockMethod method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;

public Builder withCategory(HarmCategory category) {
this.category = category;
return this;
}

public Builder withThreshold(HarmBlockThreshold threshold) {
this.threshold = threshold;
return this;
}

public Builder withMethod(HarmBlockMethod method) {
this.method = method;
return this;
}

public VertexAiGeminiSafetySetting build() {
return new VertexAiGeminiSafetySetting(category, threshold, method);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.model.Media;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
Expand Down Expand Up @@ -91,6 +92,18 @@ void googleSearchTool() {
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
}

@Test
void testSafetySettings() {
List<VertexAiGeminiSafetySetting> safetySettings = List.of(new VertexAiGeminiSafetySetting.Builder()
.withCategory(VertexAiGeminiSafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
.withThreshold(VertexAiGeminiSafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
.build());
Prompt prompt = new Prompt("What are common digital attack vectors?",
VertexAiGeminiChatOptions.builder().withSafetySettings(safetySettings).build());
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("SAFETY");
}

@NotNull
private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo
| spring.ai.vertex.ai.gemini.chat.options.presencePenalty | | -
| spring.ai.vertex.ai.gemini.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | -
| spring.ai.vertex.ai.gemini.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false
| spring.ai.vertex.ai.gemini.chat.options.safetySettings | List of safety settings to control safety filters, as defined by https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters[Vertex AI Safety Filters]. Each safety setting can have a method, threshold, and category. | -

|====

Expand Down