Skip to content

Commit e3ed72e

Browse files
committed
feat(gemini) Add VertexAiGeminiSafetySetting wrapper class
1 parent 5d6b589 commit e3ed72e

File tree

4 files changed

+190
-12
lines changed

4 files changed

+190
-12
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.google.cloud.vertexai.api.GenerationConfig;
3737
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
3838
import com.google.cloud.vertexai.api.Part;
39+
import com.google.cloud.vertexai.api.SafetySetting;
3940
import com.google.cloud.vertexai.api.Schema;
4041
import com.google.cloud.vertexai.api.Tool;
4142
import com.google.cloud.vertexai.generativeai.GenerativeModel;
@@ -46,6 +47,7 @@
4647
import io.micrometer.observation.Observation;
4748
import io.micrometer.observation.ObservationRegistry;
4849
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
50+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
4951
import reactor.core.publisher.Flux;
5052

5153
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -454,7 +456,7 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
454456

455457
var generativeModelBuilder = new GenerativeModel.Builder().setModelName(this.defaultOptions.getModel())
456458
.setVertexAi(this.vertexAI)
457-
.setSafetySettings(this.defaultOptions.getSafetySettings());
459+
.setSafetySettings(toGeminiSafetySettings(this.defaultOptions.getSafetySettings()));
458460

459461
if (prompt.getOptions() != null) {
460462
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
@@ -502,7 +504,7 @@ GeminiRequest createGeminiRequest(Prompt prompt, VertexAiGeminiChatOptions updat
502504

503505
if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options
504506
&& !CollectionUtils.isEmpty(options.getSafetySettings())) {
505-
generativeModelBuilder.setSafetySettings(options.getSafetySettings());
507+
generativeModelBuilder.setSafetySettings(toGeminiSafetySettings(options.getSafetySettings()));
506508
}
507509

508510
generativeModelBuilder.setGenerationConfig(generationConfig);
@@ -563,6 +565,16 @@ private List<Content> toGeminiContent(List<Message> instrucitons) {
563565
return contents;
564566
}
565567

568+
private List<SafetySetting> toGeminiSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
569+
return safetySettings.stream()
570+
.map(safetySetting -> SafetySetting.newBuilder()
571+
.setCategoryValue(safetySetting.getCategory().getValue())
572+
.setThresholdValue(safetySetting.getThreshold().getValue())
573+
.setMethodValue(safetySetting.getMethod().getValue())
574+
.build())
575+
.toList();
576+
}
577+
566578
private List<Tool> getFunctionTools(Set<String> functionNames) {
567579

568580
final var tool = Tool.newBuilder();

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2929
import com.fasterxml.jackson.annotation.JsonProperty;
3030

31-
import com.google.cloud.vertexai.api.SafetySetting;
3231
import org.springframework.ai.model.function.FunctionCallback;
3332
import org.springframework.ai.model.function.FunctionCallingOptions;
3433
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.ChatModel;
34+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
3535
import org.springframework.util.Assert;
3636

3737
/**
@@ -119,7 +119,7 @@ public class VertexAiGeminiChatOptions implements FunctionCallingOptions {
119119
private boolean googleSearchRetrieval = false;
120120

121121
@JsonIgnore
122-
private List<SafetySetting> safetySettings = new ArrayList<>();
122+
private List<VertexAiGeminiSafetySetting> safetySettings = new ArrayList<>();
123123

124124
@JsonIgnore
125125
private Boolean proxyToolCalls;
@@ -274,11 +274,11 @@ public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
274274
this.googleSearchRetrieval = googleSearchRetrieval;
275275
}
276276

277-
public List<SafetySetting> getSafetySettings() {
277+
public List<VertexAiGeminiSafetySetting> getSafetySettings() {
278278
return safetySettings;
279279
}
280280

281-
public void setSafetySettings(List<SafetySetting> safetySettings) {
281+
public void setSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
282282
Assert.notNull(safetySettings, "safetySettings must not be null");
283283
this.safetySettings = safetySettings;
284284
}
@@ -422,7 +422,7 @@ public Builder withGoogleSearchRetrieval(boolean googleSearch) {
422422
return this;
423423
}
424424

425-
public Builder withSafetySettings(List<SafetySetting> safetySettings) {
425+
public Builder withSafetySettings(List<VertexAiGeminiSafetySetting> safetySettings) {
426426
Assert.notNull(safetySettings, "safetySettings must not be null");
427427
this.options.safetySettings = safetySettings;
428428
return this;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package org.springframework.ai.vertexai.gemini.common;
2+
3+
public class VertexAiGeminiSafetySetting {
4+
5+
/**
6+
* Enum representing different threshold levels for blocking harmful content.
7+
*/
8+
public enum HarmBlockThreshold {
9+
10+
HARM_BLOCK_THRESHOLD_UNSPECIFIED(0), BLOCK_LOW_AND_ABOVE(1), BLOCK_MEDIUM_AND_ABOVE(2), BLOCK_ONLY_HIGH(3),
11+
BLOCK_NONE(4), OFF(5);
12+
13+
private final int value;
14+
15+
HarmBlockThreshold(int value) {
16+
this.value = value;
17+
}
18+
19+
public int getValue() {
20+
return value;
21+
}
22+
23+
}
24+
25+
/**
26+
* Enum representing methods for evaluating harmful content.
27+
*/
28+
public enum HarmBlockMethod {
29+
30+
HARM_BLOCK_METHOD_UNSPECIFIED(0), SEVERITY(1), PROBABILITY(2);
31+
32+
private final int value;
33+
34+
HarmBlockMethod(int value) {
35+
this.value = value;
36+
}
37+
38+
public int getValue() {
39+
return value;
40+
}
41+
42+
}
43+
44+
/**
45+
* Enum representing different categories of harmful content.
46+
*/
47+
public enum HarmCategory {
48+
49+
HARM_CATEGORY_UNSPECIFIED(0), HARM_CATEGORY_HATE_SPEECH(1), HARM_CATEGORY_DANGEROUS_CONTENT(2),
50+
HARM_CATEGORY_HARASSMENT(3), HARM_CATEGORY_SEXUALLY_EXPLICIT(4);
51+
52+
private final int value;
53+
54+
HarmCategory(int value) {
55+
this.value = value;
56+
}
57+
58+
public int getValue() {
59+
return value;
60+
}
61+
62+
}
63+
64+
private HarmCategory category;
65+
66+
private HarmBlockThreshold threshold;
67+
68+
private HarmBlockMethod method;
69+
70+
// Default constructor
71+
public VertexAiGeminiSafetySetting() {
72+
this.category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;
73+
this.threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;
74+
this.method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;
75+
}
76+
77+
// Constructor with all fields
78+
public VertexAiGeminiSafetySetting(HarmCategory category, HarmBlockThreshold threshold, HarmBlockMethod method) {
79+
this.category = category;
80+
this.threshold = threshold;
81+
this.method = method;
82+
}
83+
84+
// Getters and setters
85+
public HarmCategory getCategory() {
86+
return category;
87+
}
88+
89+
public void setCategory(HarmCategory category) {
90+
this.category = category;
91+
}
92+
93+
public HarmBlockThreshold getThreshold() {
94+
return threshold;
95+
}
96+
97+
public void setThreshold(HarmBlockThreshold threshold) {
98+
this.threshold = threshold;
99+
}
100+
101+
public HarmBlockMethod getMethod() {
102+
return method;
103+
}
104+
105+
public void setMethod(HarmBlockMethod method) {
106+
this.method = method;
107+
}
108+
109+
@Override
110+
public String toString() {
111+
return "SafetySetting{" + "category=" + category + ", threshold=" + threshold + ", method=" + method + '}';
112+
}
113+
114+
@Override
115+
public boolean equals(Object o) {
116+
if (this == o)
117+
return true;
118+
if (o == null || getClass() != o.getClass())
119+
return false;
120+
121+
VertexAiGeminiSafetySetting that = (VertexAiGeminiSafetySetting) o;
122+
123+
if (category != that.category)
124+
return false;
125+
if (threshold != that.threshold)
126+
return false;
127+
return method == that.method;
128+
}
129+
130+
@Override
131+
public int hashCode() {
132+
int result = category != null ? category.hashCode() : 0;
133+
result = 31 * result + (threshold != null ? threshold.hashCode() : 0);
134+
result = 31 * result + (method != null ? method.hashCode() : 0);
135+
return result;
136+
}
137+
138+
public static class Builder {
139+
140+
private HarmCategory category = HarmCategory.HARM_CATEGORY_UNSPECIFIED;
141+
142+
private HarmBlockThreshold threshold = HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED;
143+
144+
private HarmBlockMethod method = HarmBlockMethod.HARM_BLOCK_METHOD_UNSPECIFIED;
145+
146+
public Builder withCategory(HarmCategory category) {
147+
this.category = category;
148+
return this;
149+
}
150+
151+
public Builder withThreshold(HarmBlockThreshold threshold) {
152+
this.threshold = threshold;
153+
return this;
154+
}
155+
156+
public Builder withMethod(HarmBlockMethod method) {
157+
this.method = method;
158+
return this;
159+
}
160+
161+
public VertexAiGeminiSafetySetting build() {
162+
return new VertexAiGeminiSafetySetting(category, threshold, method);
163+
}
164+
165+
}
166+
167+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525

2626
import com.google.cloud.vertexai.Transport;
2727
import com.google.cloud.vertexai.VertexAI;
28-
import com.google.cloud.vertexai.api.HarmCategory;
29-
import com.google.cloud.vertexai.api.SafetySetting;
3028
import org.jetbrains.annotations.NotNull;
3129
import org.junit.jupiter.api.Test;
3230
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -43,6 +41,7 @@
4341
import org.springframework.ai.converter.ListOutputConverter;
4442
import org.springframework.ai.converter.MapOutputConverter;
4543
import org.springframework.ai.model.Media;
44+
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
4645
import org.springframework.beans.factory.annotation.Autowired;
4746
import org.springframework.beans.factory.annotation.Value;
4847
import org.springframework.boot.SpringBootConfiguration;
@@ -95,9 +94,9 @@ void googleSearchTool() {
9594

9695
@Test
9796
void testSafetySettings() {
98-
List<SafetySetting> safetySettings = List.of(SafetySetting.newBuilder()
99-
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
100-
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
97+
List<VertexAiGeminiSafetySetting> safetySettings = List.of(new VertexAiGeminiSafetySetting.Builder()
98+
.withCategory(VertexAiGeminiSafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
99+
.withThreshold(VertexAiGeminiSafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
101100
.build());
102101
Prompt prompt = new Prompt("What are common digital attack vectors?",
103102
VertexAiGeminiChatOptions.builder().withSafetySettings(safetySettings).build());

0 commit comments

Comments
 (0)