Skip to content

Commit deffd1b

Browse files
committed
feat(gemini) Support Safety Settings for VertexAiGeminiChatModel
1 parent c057148 commit deffd1b

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,8 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
453453
GenerationConfig generationConfig = this.generationConfig;
454454

455455
var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel())
456-
.setVertexAi(this.vertexAI);
456+
.setVertexAi(this.vertexAI)
457+
.setSafetySettings(this.defaultOptions.getSafetySettings());
457458

458459
if (prompt.getOptions() != null) {
459460
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
@@ -499,6 +500,10 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
499500
generativeModelBuilder.setTools(tools);
500501
}
501502

503+
if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options && options.getSafetySettings() != null) {
504+
generativeModelBuilder.setSafetySettings(options.getSafetySettings());
505+
}
506+
502507
generativeModelBuilder.setGenerationConfig(generationConfig);
503508

504509
GenerativeModel generativeModel = generativeModelBuilder.build();

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2929
import com.fasterxml.jackson.annotation.JsonProperty;
3030

31+
import com.google.cloud.vertexai.api.SafetySetting;
3132
import org.springframework.ai.model.function.FunctionCallback;
3233
import org.springframework.ai.model.function.FunctionCallingOptions;
3334
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
@@ -117,6 +118,9 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
117118
@JsonIgnore
118119
private boolean googleSearchRetrieval = false;
119120

121+
@JsonIgnore
122+
private List<SafetySetting> safetySettings = new ArrayList<>();
123+
120124
@JsonIgnore
121125
private Boolean proxyToolCalls;
122126

@@ -143,6 +147,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
143147
options.setFunctions(fromOptions.getFunctions());
144148
options.setResponseMimeType(fromOptions.getResponseMimeType());
145149
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
150+
options.setSafetySettings(fromOptions.getSafetySettings());
146151
options.setProxyToolCalls(fromOptions.getProxyToolCalls());
147152
options.setToolContext(fromOptions.getToolContext());
148153
return options;
@@ -269,6 +274,14 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
269274
this.googleSearchRetrieval = googleSearchRetrieval;
270275
}
271276

277+
public List<SafetySetting> getSafetySettings() {
278+
return safetySettings;
279+
}
280+
281+
public void setSafetySettings(List<SafetySetting> safetySettings) {
282+
this.safetySettings = safetySettings;
283+
}
284+
272285
@Override
273286
public Boolean getProxyToolCalls() {
274287
return this.proxyToolCalls;
@@ -408,6 +421,11 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) {
408421
return this;
409422
}
410423

424+
public Builder withSafetySettings(List<SafetySetting> safetySettings) {
425+
this.options.safetySettings = safetySettings;
426+
return this;
427+
}
428+
411429
public Builder withProxyToolCalls(boolean proxyToolCalls) {
412430
this.options.proxyToolCalls = proxyToolCalls;
413431
return this;

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import com.google.cloud.vertexai.Transport;
2727
import com.google.cloud.vertexai.VertexAI;
28+
import com.google.cloud.vertexai.api.HarmCategory;
29+
import com.google.cloud.vertexai.api.SafetySetting;
2830
import org.jetbrains.annotations.NotNull;
2931
import org.junit.jupiter.api.Test;
3032
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -91,6 +93,22 @@ void googleSearchTool() {
9193
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
9294
}
9395

96+
@Test
97+
void testSafetySettings() {
98+
List<SafetySetting> safetySettings = List.of(
99+
SafetySetting.newBuilder()
100+
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
101+
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
102+
.build()
103+
);
104+
Prompt prompt = new Prompt(
105+
"What are common digital attack vectors?",
106+
VertexAiGeminiChatOptions.builder().withSafetySettings(safetySettings).build()
107+
);
108+
ChatResponse response = this.chatModel.call(prompt);
109+
assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("SAFETY");
110+
}
111+
94112
@NotNull
95113
private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) {
96114
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";

0 commit comments

Comments
 (0)