Skip to content

Commit 086117e

Browse files
committed
High-level API function calling support for VertexAI Gemini
- Refactor VertexAiGeminiChatModel's function-calling handling to use the Spring AI tool-call abstractions (AbstractToolCallSupport).
1 parent fac38c5 commit 086117e

File tree

9 files changed: 696 additions and 142 deletions.
models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import org.springframework.ai.chat.messages.Message;
2222
import org.springframework.ai.chat.messages.MessageType;
2323
import org.springframework.ai.chat.messages.ToolResponseMessage;
24+
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
2425
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2526
import org.springframework.ai.chat.metadata.RateLimit;
2627
import org.springframework.ai.chat.model.ChatModel;
@@ -265,7 +266,7 @@ private List<Message> handleToolCallRequests(List<Message> previousMessages, Cha
265266
AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.content(), Map.of(),
266267
assistantToolCalls);
267268

268-
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage);
269+
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage, false);
269270

270271
// History
271272
List<Message> messages = new ArrayList<>(previousMessages);
@@ -337,8 +338,11 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
337338
}
338339
else if (message.getMessageType() == MessageType.TOOL) {
339340
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
340-
return new ChatCompletionMessage(toolMessage.getContent(), ChatCompletionMessage.Role.TOOL,
341-
toolMessage.getName(), toolMessage.getId(), null);
341+
Assert.isTrue(toolMessage.getResponses().size() == 1,
342+
"ToolResponseMessage must have exactly one response");
343+
ToolResponse response = toolMessage.getResponses().get(0);
344+
return new ChatCompletionMessage(response.respoinse(), ChatCompletionMessage.Role.TOOL, response.name(),
345+
response.id(), null);
342346
}
343347
else {
344348
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModel3IT.java renamed to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343

4444
import static org.assertj.core.api.Assertions.assertThat;
4545

46-
@SpringBootTest(classes = OpenAiChatModel3IT.Config.class)
46+
@SpringBootTest(classes = OpenAiChatModelFunctionCallingIT.Config.class)
4747
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
48-
class OpenAiChatModel3IT {
48+
class OpenAiChatModelFunctionCallingIT {
4949

50-
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel3IT.class);
50+
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelFunctionCallingIT.class);
5151

5252
@Autowired
5353
ChatModel chatModel;
@@ -72,9 +72,7 @@ void functionCallTest() {
7272

7373
logger.info("Response: {}", response);
7474

75-
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
76-
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
77-
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
75+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
7876
}
7977

8078
@Test
@@ -105,9 +103,7 @@ void streamFunctionCallTest() {
105103
.collect(Collectors.joining());
106104
logger.info("Response: {}", content);
107105

108-
assertThat(content).containsAnyOf("30.0", "30");
109-
assertThat(content).containsAnyOf("10.0", "10");
110-
assertThat(content).containsAnyOf("15.0", "15");
106+
assertThat(content).contains("30", "10", "15");
111107
}
112108

113109
@SpringBootConfiguration

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 118 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,39 @@
1515
*/
1616
package org.springframework.ai.vertexai.gemini;
1717

18+
import java.util.ArrayList;
19+
import java.util.HashSet;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Set;
23+
import java.util.stream.Collectors;
24+
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
import org.springframework.ai.chat.messages.Message;
27+
import org.springframework.ai.chat.messages.MessageType;
28+
import org.springframework.ai.chat.messages.ToolResponseMessage;
29+
import org.springframework.ai.chat.messages.UserMessage;
30+
import org.springframework.ai.chat.model.ChatModel;
31+
import org.springframework.ai.chat.model.ChatResponse;
32+
import org.springframework.ai.chat.model.Generation;
33+
import org.springframework.ai.chat.prompt.ChatOptions;
34+
import org.springframework.ai.chat.prompt.Prompt;
35+
import org.springframework.ai.model.ChatModelDescription;
36+
import org.springframework.ai.model.ModelOptionsUtils;
37+
import org.springframework.ai.model.function.AbstractToolCallSupport;
38+
import org.springframework.ai.model.function.FunctionCallbackContext;
39+
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
40+
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
41+
import org.springframework.beans.factory.DisposableBean;
42+
import org.springframework.lang.NonNull;
43+
import org.springframework.util.Assert;
44+
import org.springframework.util.CollectionUtils;
45+
import org.springframework.util.StringUtils;
46+
1847
import com.fasterxml.jackson.annotation.JsonInclude;
1948
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2049
import com.google.cloud.vertexai.VertexAI;
2150
import com.google.cloud.vertexai.api.Content;
22-
import com.google.cloud.vertexai.api.Content.Builder;
2351
import com.google.cloud.vertexai.api.FunctionCall;
2452
import com.google.cloud.vertexai.api.FunctionDeclaration;
2553
import com.google.cloud.vertexai.api.FunctionResponse;
@@ -34,42 +62,17 @@
3462
import com.google.cloud.vertexai.generativeai.ResponseStream;
3563
import com.google.protobuf.Struct;
3664
import com.google.protobuf.util.JsonFormat;
37-
import org.springframework.ai.chat.model.ChatModel;
38-
import org.springframework.ai.chat.model.ChatResponse;
39-
import org.springframework.ai.chat.model.Generation;
40-
import org.springframework.ai.chat.messages.AssistantMessage;
41-
import org.springframework.ai.chat.messages.Message;
42-
import org.springframework.ai.chat.messages.MessageType;
43-
import org.springframework.ai.chat.messages.UserMessage;
44-
import org.springframework.ai.chat.prompt.ChatOptions;
45-
import org.springframework.ai.chat.prompt.Prompt;
46-
import org.springframework.ai.model.ChatModelDescription;
47-
import org.springframework.ai.model.ModelOptionsUtils;
48-
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
49-
import org.springframework.ai.model.function.FunctionCallbackContext;
50-
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
51-
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
52-
import org.springframework.beans.factory.DisposableBean;
53-
import org.springframework.lang.NonNull;
54-
import org.springframework.util.Assert;
55-
import org.springframework.util.CollectionUtils;
56-
import org.springframework.util.StringUtils;
57-
import reactor.core.publisher.Flux;
5865

59-
import java.util.ArrayList;
60-
import java.util.HashSet;
61-
import java.util.List;
62-
import java.util.Set;
63-
import java.util.stream.Collectors;
66+
import reactor.core.publisher.Flux;
67+
import reactor.core.publisher.Mono;
6468

6569
/**
6670
* @author Christian Tzolov
6771
* @author Grogdunn
6872
* @author luocongqiu
6973
* @since 0.8.1
7074
*/
71-
public class VertexAiGeminiChatModel
72-
extends AbstractFunctionCallSupport<Content, VertexAiGeminiChatModel.GeminiRequest, GenerateContentResponse>
75+
public class VertexAiGeminiChatModel extends AbstractToolCallSupport<GenerateContentResponse>
7376
implements ChatModel, DisposableBean {
7477

7578
private final static boolean IS_RUNTIME_CALL = true;
@@ -157,7 +160,15 @@ public ChatResponse call(Prompt prompt) {
157160

158161
var geminiRequest = createGeminiRequest(prompt);
159162

160-
GenerateContentResponse response = this.callWithFunctionSupport(geminiRequest);
163+
GenerateContentResponse response = this.getContentResponse(geminiRequest);
164+
165+
// GenerateContentResponse response = this.callWithFunctionSupport(geminiRequest);
166+
167+
if (this.isToolFunctionCall(response)) {
168+
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), response);
169+
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
170+
171+
}
161172

162173
List<Generation> generations = response.getCandidatesList()
163174
.stream()
@@ -170,6 +181,32 @@ public ChatResponse call(Prompt prompt) {
170181
return new ChatResponse(generations, toChatResponseMetadata(response));
171182
}
172183

184+
public List<Message> handleToolCallRequests(List<Message> previousMessages, GenerateContentResponse response) {
185+
186+
Content assistantContent = response.getCandidatesList().get(0).getContent();
187+
188+
List<AssistantMessage.ToolCall> assistantToolCalls = assistantContent.getPartsList()
189+
.stream()
190+
.filter(part -> part.hasFunctionCall())
191+
.map(part -> {
192+
FunctionCall functionCall = part.getFunctionCall();
193+
var functionName = functionCall.getName();
194+
String functionArguments = structToJson(functionCall.getArgs());
195+
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
196+
})
197+
.toList();
198+
199+
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), assistantToolCalls);
200+
201+
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage, true);
202+
203+
// History
204+
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
205+
toolCallMessageConversation.add(assistantMessage);
206+
toolCallMessageConversation.addAll(toolResponseMessages);
207+
return toolCallMessageConversation;
208+
}
209+
173210
@Override
174211
public Flux<ChatResponse> stream(Prompt prompt) {
175212
try {
@@ -179,9 +216,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
179216
ResponseStream<GenerateContentResponse> responseStream = request.model
180217
.generateContentStream(request.contents);
181218

182-
return Flux.fromStream(responseStream.stream())
183-
.switchMap(r -> handleFunctionCallOrReturnStream(request, Flux.just(r)))
184-
.map(response -> {
219+
return Flux.fromStream(responseStream.stream()).switchMap(response -> {
220+
if (this.isToolFunctionCall(response)) {
221+
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
222+
response);
223+
// Recursively call the stream method with the tool call message
224+
// conversation that contains the call responses.
225+
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
226+
}
227+
228+
return Mono.just(response).map(response2 -> {
185229
List<Generation> generations = response.getCandidatesList()
186230
.stream()
187231
.map(candidate -> candidate.getContent().getPartsList())
@@ -191,7 +235,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
191235
.toList();
192236

193237
return new ChatResponse(generations, toChatResponseMetadata(response));
238+
194239
});
240+
});
195241
}
196242
catch (Exception e) {
197243
throw new RuntimeException("Failed to generate content", e);
@@ -302,7 +348,8 @@ private List<Content> toGeminiContent(Prompt prompt) {
302348

303349
List<Content> contents = prompt.getInstructions()
304350
.stream()
305-
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
351+
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT
352+
|| m.getMessageType() == MessageType.TOOL)
306353
.map(message -> Content.newBuilder()
307354
.setRole(toGeminiMessageType(message.getMessageType()).getValue())
308355
.addAllParts(messageToGeminiParts(message))
@@ -318,6 +365,7 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)
318365

319366
switch (type) {
320367
case USER:
368+
case TOOL:
321369
return GeminiMessageType.USER;
322370
case ASSISTANT:
323371
return GeminiMessageType.MODEL;
@@ -348,7 +396,34 @@ static List<Part> messageToGeminiParts(Message message) {
348396
return parts;
349397
}
350398
else if (message instanceof AssistantMessage assistantMessage) {
351-
return List.of(Part.newBuilder().setText(assistantMessage.getContent()).build());
399+
List<Part> parts = new ArrayList<>();
400+
if (StringUtils.hasText(assistantMessage.getContent())) {
401+
List.of(Part.newBuilder().setText(assistantMessage.getContent()).build());
402+
}
403+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
404+
parts.addAll(assistantMessage.getToolCalls()
405+
.stream()
406+
.map(toolCall -> Part.newBuilder()
407+
.setFunctionCall(FunctionCall.newBuilder()
408+
.setName(toolCall.name())
409+
.setArgs(jsonToStruct(toolCall.arguments()))
410+
.build())
411+
.build())
412+
.toList());
413+
}
414+
return parts;
415+
}
416+
else if (message instanceof ToolResponseMessage toolResponseMessage) {
417+
418+
return toolResponseMessage.getResponses()
419+
.stream()
420+
.map(response -> Part.newBuilder()
421+
.setFunctionResponse(FunctionResponse.newBuilder()
422+
.setName(response.name())
423+
.setResponse(jsonToStruct(response.respoinse()))
424+
.build())
425+
.build())
426+
.toList();
352427
}
353428
else {
354429
throw new IllegalArgumentException("Gemini doesn't support message type: " + message.getClass());
@@ -402,58 +477,7 @@ private static Schema jsonToSchema(String json) {
402477
}
403478
}
404479

405-
@Override
406-
public void destroy() throws Exception {
407-
if (this.vertexAI != null) {
408-
this.vertexAI.close();
409-
}
410-
}
411-
412-
@Override
413-
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
414-
List<Content> conversationHistory) {
415-
416-
var iterator = responseMessage.getPartsList().iterator();
417-
418-
Builder builder = Content.newBuilder();
419-
while (iterator.hasNext()) {
420-
421-
FunctionCall functionCall = iterator.next().getFunctionCall();
422-
423-
var functionName = functionCall.getName();
424-
String functionArguments = structToJson(functionCall.getArgs());
425-
426-
if (!this.functionCallbackRegister.containsKey(functionName)) {
427-
throw new IllegalStateException("No function callback found for function name: " + functionName);
428-
}
429-
430-
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
431-
432-
builder.addParts(Part.newBuilder()
433-
.setFunctionResponse(FunctionResponse.newBuilder()
434-
.setName(functionCall.getName())
435-
.setResponse(jsonToStruct(functionResponse))
436-
.build())
437-
.build());
438-
439-
}
440-
conversationHistory.add(builder.build());
441-
442-
return new GeminiRequest(conversationHistory, previousRequest.model());
443-
}
444-
445-
@Override
446-
protected List<Content> doGetUserMessages(GeminiRequest request) {
447-
return request.contents;
448-
}
449-
450-
@Override
451-
protected Content doGetToolResponseMessage(GenerateContentResponse response) {
452-
return response.getCandidatesList().get(0).getContent();
453-
}
454-
455-
@Override
456-
protected GenerateContentResponse doChatCompletion(GeminiRequest request) {
480+
private GenerateContentResponse getContentResponse(GeminiRequest request) {
457481
try {
458482
return request.model.generateContent(request.contents);
459483
}
@@ -462,19 +486,6 @@ protected GenerateContentResponse doChatCompletion(GeminiRequest request) {
462486
}
463487
}
464488

465-
@Override
466-
protected Flux<GenerateContentResponse> doChatCompletionStream(GeminiRequest request) {
467-
try {
468-
ResponseStream<GenerateContentResponse> responseStream = request.model
469-
.generateContentStream(request.contents);
470-
471-
return Flux.fromStream(responseStream.stream());
472-
}
473-
catch (Exception e) {
474-
throw new RuntimeException("Failed to generate content", e);
475-
}
476-
}
477-
478489
@Override
479490
protected boolean isToolFunctionCall(GenerateContentResponse response) {
480491
if (response == null || CollectionUtils.isEmpty(response.getCandidatesList())
@@ -490,4 +501,11 @@ public ChatOptions getDefaultOptions() {
490501
return VertexAiGeminiChatOptions.fromOptions(this.defaultOptions);
491502
}
492503

504+
@Override
505+
public void destroy() throws Exception {
506+
if (this.vertexAI != null) {
507+
this.vertexAI.close();
508+
}
509+
}
510+
493511
}

0 commit comments

Comments
 (0)