Skip to content

Commit 12773eb

Browse files
committed
feat(gemini) Support Safety Settings for VertexAiGeminiChatModel
feat(gemini) not null checks on safetySettings feat(gemini) check safety settings not empty instead of not null feat(gemini) Add VertexAiGeminiSafetySetting wrapper class feat(gemini) Update documentation with new property
1 parent c057148 commit 12773eb

File tree

5 files changed

+220
-1
lines changed

5 files changed

+220
-1
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.google.cloud.vertexai.api.GenerationConfig;
3737
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
3838
import com.google.cloud.vertexai.api.Part;
39+
import com.google.cloud.vertexai.api.SafetySetting;
3940
import com.google.cloud.vertexai.api.Schema;
4041
import com.google.cloud.vertexai.api.Tool;
4142
import com.google.cloud.vertexai.generativeai.GenerativeModel;
@@ -46,6 +47,7 @@
4647
import io.micrometer.observation.Observation;
4748
import io.micrometer.observation.ObservationRegistry;
4849
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
50+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
4951
import reactor.core.publisher.Flux;
5052

5153
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -453,7 +455,8 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
453455
GenerationConfig generationConfig = this.generationConfig;
454456

455457
var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel())
456-
.setVertexAi(this.vertexAI);
458+
.setVertexAi(this.vertexAI)
459+
.setSafetySettings(toGeminiSafetySettings(this.defaultOptions.getSafetySettings()));
457460

458461
if (prompt.getOptions() != null) {
459462
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
@@ -499,6 +502,11 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
499502
generativeModelBuilder.setTools(tools);
500503
}
501504

505+
if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options
506+
&& !CollectionUtils.isEmpty(options.getSafetySettings())) {
507+
generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(options.getSafetySettings()));
508+
}
509+
502510
generativeModelBuilder.setGenerationConfig(generationConfig);
503511

504512
GenerativeModel generativeModel = generativeModelBuilder.build();
@@ -557,6 +565,16 @@ private List<Content> toGeminiContent(List<Message> instrucitons) {
557565
return contents;
558566
}
559567

568+
private List<SafetySetting> toGeminiSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
569+
return safetySettings.stream()
570+
.map(safetySetting -> SafetySetting.newBuilder()
571+
.setCategoryValue(safetySetting.getCategory().getValue())
572+
.setThresholdValue(safetySetting.getThreshold().getValue())
573+
.setMethodValue(safetySetting.getMethod().getValue())
574+
.build())
575+
.toList();
576+
}
577+
560578
private List<Tool> getFunctionTools(Set<String> functionNames) {
561579

562580
final var tool = Tool.newBuilder();

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.model.function.FunctionCallback;
3232
import org.springframework.ai.model.function.FunctionCallingOptions;
3333
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
34+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
3435
import org.springframework.util.Assert;
3536

3637
/**
@@ -117,6 +118,9 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
117118
@JsonIgnore
118119
private boolean googleSearchRetrieval = false;
119120

121+
@JsonIgnore
122+
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();
123+
120124
@JsonIgnore
121125
private Boolean proxyToolCalls;
122126

@@ -143,6 +147,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
143147
options.setFunctions(fromOptions.getFunctions());
144148
options.setResponseMimeType(fromOptions.getResponseMimeType());
145149
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
150+
options.setSafetySettings(fromOptions.getSafetySettings());
146151
options.setProxyToolCalls(fromOptions.getProxyToolCalls());
147152
options.setToolContext(fromOptions.getToolContext());
148153
return options;
@@ -269,6 +274,15 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
269274
this.googleSearchRetrieval = googleSearchRetrieval;
270275
}
271276

277+
public List<VertexAiGeminiSafetySetting> getSafetySettings() {
278+
return safetySettings;
279+
}
280+
281+
public void setSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
282+
Assert.notNull(safetySettings, "safetySettings must not be null");
283+
this.safetySettings = safetySettings;
284+
}
285+
272286
@Override
273287
public Boolean getProxyToolCalls() {
274288
return this.proxyToolCalls;
@@ -408,6 +422,12 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) {
408422
return this;
409423
}
410424

425+
public Builder withSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
426+
Assert.notNull(safetySettings, "safetySettings must not be null");
427+
this.options.safetySettings = safetySettings;
428+
return this;
429+
}
430+
411431
public Builder withProxyToolCalls(boolean proxyToolCalls) {
412432
this.options.proxyToolCalls = proxyToolCalls;
413433
return this;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package org.springframework.ai.vertexai.gemini.common;
2+
3+
public class VertexAiGeminiSafetySetting {
4+
5+
/**
6+
* Enum representing different threshold levels for blocking harmful content.
7+
*/
8+
public enum HarmBlockThreshold {
9+
10+
HARM_BLOCK_THRESHOLD_UNSPECIFIED(0), BLOCK_LOW_AND_ABOVE(1), BLOCK_MEDIUM_AND_ABOVE(2), BLOCK_ONLY_HIGH(3),
11+
BLOCK_NONE(4), OFF(5);
12+
13+
private final int value;
14+
15+
HarmBlockThreshold(int value) {
16+
this.value = value;
17+
}
18+
19+
public int getValue() {
20+
return value;
21+
}
22+
23+
}
24+
25+
/**
26+
* Enum representing methods for evaluating harmful content.
27+
*/
28+
public enum HarmBlockMethod {
29+
30+
HARM_BLOCK_METHOD_UNSPECIFIED(0), SEVERITY(1), PROBABILITY(2);
31+
32+
private final int value;
33+
34+
HarmBlockMethod(int value) {
35+
this.value = value;
36+
}
37+
38+
public int getValue() {
39+
return value;
40+
}
41+
42+
}
43+
44+
/**
45+
* Enum representing different categories of harmful content.
46+
*/
47+
public enum HarmCategory {
48+
49+
HARM_CATEGORY_UNSPECIFIED(0), HARM_CATEGORY_HATE_SPEECH(1), HARM_CATEGORY_DANGEROUS_CONTENT(2),
50+
HARM_CATEGORY_HARASSMENT(3), HARM_CATEGORY_SEXUALLY_EXPLICIT(4);
51+
52+
private final int value;
53+
54+
HarmCategory(int value) {
55+
this.value = value;
56+
}
57+
58+
public int getValue() {
59+
return value;
60+
}
61+
62+
}
63+
64+
private HarmCategory category;
65+
66+
private HarmBlockThreshold threshold;
67+
68+
private HarmBlockMethod method;
69+
70+
// Default constructor
71+
public VertexAiGeminiSafetySetting() {
72+
this.category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;
73+
this.threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;
74+
this.method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;
75+
}
76+
77+
// Constructor with all fields
78+
public VertexAiGeminiSafetySetting(HarmCategory category, HarmBlockThreshold threshold, HarmBlockMethod method) {
79+
this.category = category;
80+
this.threshold = threshold;
81+
this.method = method;
82+
}
83+
84+
// Getters and setters
85+
public HarmCategory getCategory() {
86+
return category;
87+
}
88+
89+
public void setCategory(HarmCategory category) {
90+
this.category = category;
91+
}
92+
93+
public HarmBlockThreshold getThreshold() {
94+
return threshold;
95+
}
96+
97+
public void setThreshold(HarmBlockThreshold threshold) {
98+
this.threshold = threshold;
99+
}
100+
101+
public HarmBlockMethod getMethod() {
102+
return method;
103+
}
104+
105+
public void setMethod(HarmBlockMethod method) {
106+
this.method = method;
107+
}
108+
109+
@Override
110+
public String toString() {
111+
return "SafetySetting{" + "category=" + category + ", threshold=" + threshold + ", method=" + method + '}';
112+
}
113+
114+
@Override
115+
public boolean equals(Object o) {
116+
if (this == o)
117+
return true;
118+
if (o == null || getClass() != o.getClass())
119+
return false;
120+
121+
VertexAiGeminiSafetySetting that = (VertexAiGeminiSafetySetting) o;
122+
123+
if (category != that.category)
124+
return false;
125+
if (threshold != that.threshold)
126+
return false;
127+
return method == that.method;
128+
}
129+
130+
@Override
131+
public int hashCode() {
132+
int result = category != null ? category.hashCode() : 0;
133+
result = 31 * result + (threshold != null ? threshold.hashCode() : 0);
134+
result = 31 * result + (method != null ? method.hashCode() : 0);
135+
return result;
136+
}
137+
138+
public static class Builder {
139+
140+
private HarmCategory category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;
141+
142+
private HarmBlockThreshold threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;
143+
144+
private HarmBlockMethod method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;
145+
146+
public Builder withCategory(HarmCategory category) {
147+
this.category = category;
148+
return this;
149+
}
150+
151+
public Builder withThreshold(HarmBlockThreshold threshold) {
152+
this.threshold = threshold;
153+
return this;
154+
}
155+
156+
public Builder withMethod(HarmBlockMethod method) {
157+
this.method = method;
158+
return this;
159+
}
160+
161+
public VertexAiGeminiSafetySetting build() {
162+
return new VertexAiGeminiSafetySetting(category, threshold, method);
163+
}
164+
165+
}
166+
167+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.springframework.ai.converter.ListOutputConverter;
4242
import org.springframework.ai.converter.MapOutputConverter;
4343
import org.springframework.ai.model.Media;
44+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
4445
import org.springframework.beans.factory.annotation.Autowired;
4546
import org.springframework.beans.factory.annotation.Value;
4647
import org.springframework.boot.SpringBootConfiguration;
@@ -91,6 +92,18 @@ void googleSearchTool() {
9192
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
9293
}
9394

95+
@Test
96+
void testSafetySettings() {
97+
List<VertexAiGeminiSafetySetting> safetySettings = List.of(new VertexAiGeminiSafetySetting.Builder()
98+
.withCategory(VertexAiGeminiSafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
99+
.withThreshold(VertexAiGeminiSafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
100+
.build());
101+
Prompt prompt = new Prompt("What are common digital attack vectors?",
102+
VertexAiGeminiChatOptions.builder().withSafetySettings(safetySettings).build());
103+
ChatResponse response = this.chatModel.call(prompt);
104+
assertThat(response.getResult().getMetadata().getFinishReason()).isEqualTo("SAFETY");
105+
}
106+
94107
@NotNull
95108
private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) {
96109
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo
8484
| spring.ai.vertex.ai.gemini.chat.options.presencePenalty | | -
8585
| spring.ai.vertex.ai.gemini.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | -
8686
| spring.ai.vertex.ai.gemini.chat.options.proxy-tool-calls | If true, the Spring AI will not handle the function calls internally, but will proxy them to the client. Then is the client's responsibility to handle the function calls, dispatch them to the appropriate function, and return the results. If false (the default), the Spring AI will handle the function calls internally. Applicable only for chat models with function calling support | false
87+
| spring.ai.vertex.ai.gemini.chat.options.safetySettings | List of safety settings to control safety filters, as defined by https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters[Vertex AI Safety Filters]. Each safety setting can have a method, threshold, and category. | -
8788

8889
|====
8990

0 commit comments

Comments
 (0)