Skip to content

Commit a59dd13

Browse files
committed
feat(vertex-ai-gemini): enhance jsonToStruct to support JSON arrays
Improve the jsonToStruct method in VertexAiGeminiChatModel to handle JSON arrays in addition to JSON objects. When a JSON array is detected, it's now properly converted to a Protobuf Struct with an items field containing the array elements. Resolves #2647 , #2849 Signed-off-by: Christian Tzolov <[email protected]>
1 parent 846165d commit a59dd13

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
package org.springframework.ai.vertexai.gemini;
1818

19-
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
2019
import java.util.ArrayList;
2120
import java.util.Collection;
2221
import java.util.List;
2322
import java.util.Map;
2423

2524
import com.fasterxml.jackson.annotation.JsonInclude;
2625
import com.fasterxml.jackson.annotation.JsonInclude.Include;
26+
import com.fasterxml.jackson.databind.JsonNode;
2727
import com.google.cloud.vertexai.VertexAI;
2828
import com.google.cloud.vertexai.api.Candidate;
2929
import com.google.cloud.vertexai.api.Candidate.FinishReason;
@@ -33,15 +33,16 @@
3333
import com.google.cloud.vertexai.api.FunctionResponse;
3434
import com.google.cloud.vertexai.api.GenerateContentResponse;
3535
import com.google.cloud.vertexai.api.GenerationConfig;
36-
import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
3736
import com.google.cloud.vertexai.api.Part;
3837
import com.google.cloud.vertexai.api.SafetySetting;
3938
import com.google.cloud.vertexai.api.Schema;
4039
import com.google.cloud.vertexai.api.Tool;
40+
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
4141
import com.google.cloud.vertexai.generativeai.GenerativeModel;
4242
import com.google.cloud.vertexai.generativeai.PartMaker;
4343
import com.google.cloud.vertexai.generativeai.ResponseStream;
4444
import com.google.protobuf.Struct;
45+
import com.google.protobuf.Value;
4546
import com.google.protobuf.util.JsonFormat;
4647
import io.micrometer.observation.Observation;
4748
import io.micrometer.observation.ObservationRegistry;
@@ -226,7 +227,8 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
226227
this.observationRegistry = observationRegistry;
227228
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
228229

229-
// Wrap the provided tool calling manager in a VertexToolCallingManager to ensure
230+
// Wrap the provided tool calling manager in a VertexToolCallingManager to
231+
// ensure
230232
// compatibility with Vertex AI's OpenAPI schema format.
231233
if (toolCallingManager instanceof VertexToolCallingManager) {
232234
this.toolCallingManager = toolCallingManager;
@@ -334,8 +336,34 @@ private static String structToJson(Struct struct) {
334336

335337
private static Struct jsonToStruct(String json) {
336338
try {
337-
var structBuilder = Struct.newBuilder();
338-
JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder);
339+
JsonNode rootNode = ModelOptionsUtils.OBJECT_MAPPER.readTree(json);
340+
341+
Struct.Builder structBuilder = Struct.newBuilder();
342+
343+
if (rootNode.isArray()) {
344+
// Handle JSON array
345+
List<Value> values = new ArrayList<>();
346+
347+
for (JsonNode element : rootNode) {
348+
String elementJson = element.toString();
349+
Struct.Builder elementBuilder = Struct.newBuilder();
350+
JsonFormat.parser().ignoringUnknownFields().merge(elementJson, elementBuilder);
351+
352+
// Add each parsed object as a value in an array field
353+
values.add(Value.newBuilder().setStructValue(elementBuilder.build()).build());
354+
}
355+
356+
// Add the array to the main struct with a field name like "items"
357+
structBuilder.putFields("items",
358+
Value.newBuilder()
359+
.setListValue(com.google.protobuf.ListValue.newBuilder().addAllValues(values).build())
360+
.build());
361+
}
362+
else {
363+
// Original behavior for single JSON object
364+
JsonFormat.parser().ignoringUnknownFields().merge(json, structBuilder);
365+
}
366+
339367
return structBuilder.build();
340368
}
341369
catch (Exception e) {

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public void functionCallTestInferredOpenApiSchema() {
126126

127127
assertThat(chatResponse.getMetadata()).isNotNull();
128128
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
129-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
129+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(330);
130130

131131
ChatResponse response2 = this.chatModel
132132
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
@@ -201,7 +201,7 @@ public void functionCallUsageTestInferredOpenApiSchemaStream() {
201201
assertThat(chatResponse).isNotNull();
202202
assertThat(chatResponse.getMetadata()).isNotNull();
203203
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
204-
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
204+
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(330);
205205

206206
}
207207

0 commit comments

Comments
 (0)