Skip to content

Commit 02b486f

Browse files
committed
Streamline the VertexAI Gemini Function Calling
- Align with AbstractToolCallSupport
1 parent aa18a67 commit 02b486f

File tree

2 files changed

+100
-86
lines changed

2 files changed

+100
-86
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 97 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -15,55 +15,59 @@
1515
*/
1616
package org.springframework.ai.vertexai.gemini;
1717

18-
import com.fasterxml.jackson.annotation.JsonInclude;
19-
import com.fasterxml.jackson.annotation.JsonInclude.Include;
20-
import com.google.cloud.vertexai.VertexAI;
21-
import com.google.cloud.vertexai.api.Content;
22-
import com.google.cloud.vertexai.api.FunctionCall;
23-
import com.google.cloud.vertexai.api.FunctionDeclaration;
24-
import com.google.cloud.vertexai.api.FunctionResponse;
25-
import com.google.cloud.vertexai.api.GenerateContentResponse;
26-
import com.google.cloud.vertexai.api.GenerationConfig;
27-
import com.google.cloud.vertexai.api.Part;
28-
import com.google.cloud.vertexai.api.Schema;
29-
import com.google.cloud.vertexai.api.Tool;
30-
import com.google.cloud.vertexai.generativeai.GenerativeModel;
31-
import com.google.cloud.vertexai.generativeai.PartMaker;
32-
import com.google.cloud.vertexai.generativeai.ResponseStream;
33-
import com.google.protobuf.Struct;
34-
import com.google.protobuf.util.JsonFormat;
18+
import java.util.ArrayList;
19+
import java.util.Collection;
20+
import java.util.HashSet;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Set;
24+
3525
import org.springframework.ai.chat.messages.AssistantMessage;
36-
import org.springframework.ai.model.Media;
3726
import org.springframework.ai.chat.messages.Message;
3827
import org.springframework.ai.chat.messages.MessageType;
3928
import org.springframework.ai.chat.messages.SystemMessage;
4029
import org.springframework.ai.chat.messages.ToolResponseMessage;
4130
import org.springframework.ai.chat.messages.UserMessage;
31+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
4232
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
33+
import org.springframework.ai.chat.model.AbstractToolCallSupport;
4334
import org.springframework.ai.chat.model.ChatModel;
4435
import org.springframework.ai.chat.model.ChatResponse;
4536
import org.springframework.ai.chat.model.Generation;
4637
import org.springframework.ai.chat.prompt.ChatOptions;
4738
import org.springframework.ai.chat.prompt.Prompt;
4839
import org.springframework.ai.model.ChatModelDescription;
40+
import org.springframework.ai.model.Media;
4941
import org.springframework.ai.model.ModelOptionsUtils;
50-
import org.springframework.ai.chat.model.AbstractToolCallSupport;
5142
import org.springframework.ai.model.function.FunctionCallbackContext;
5243
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
5344
import org.springframework.beans.factory.DisposableBean;
5445
import org.springframework.lang.NonNull;
5546
import org.springframework.util.Assert;
5647
import org.springframework.util.CollectionUtils;
5748
import org.springframework.util.StringUtils;
58-
import reactor.core.publisher.Flux;
59-
import reactor.core.publisher.Mono;
6049

61-
import java.util.ArrayList;
62-
import java.util.Collection;
63-
import java.util.HashSet;
64-
import java.util.List;
65-
import java.util.Map;
66-
import java.util.Set;
50+
import com.fasterxml.jackson.annotation.JsonInclude;
51+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
52+
import com.google.cloud.vertexai.VertexAI;
53+
import com.google.cloud.vertexai.api.Candidate;
54+
import com.google.cloud.vertexai.api.Candidate.FinishReason;
55+
import com.google.cloud.vertexai.api.Content;
56+
import com.google.cloud.vertexai.api.FunctionCall;
57+
import com.google.cloud.vertexai.api.FunctionDeclaration;
58+
import com.google.cloud.vertexai.api.FunctionResponse;
59+
import com.google.cloud.vertexai.api.GenerateContentResponse;
60+
import com.google.cloud.vertexai.api.GenerationConfig;
61+
import com.google.cloud.vertexai.api.Part;
62+
import com.google.cloud.vertexai.api.Schema;
63+
import com.google.cloud.vertexai.api.Tool;
64+
import com.google.cloud.vertexai.generativeai.GenerativeModel;
65+
import com.google.cloud.vertexai.generativeai.PartMaker;
66+
import com.google.cloud.vertexai.generativeai.ResponseStream;
67+
import com.google.protobuf.Struct;
68+
import com.google.protobuf.util.JsonFormat;
69+
70+
import reactor.core.publisher.Flux;
6771

6872
/**
6973
* @author Christian Tzolov
@@ -161,47 +165,22 @@ public ChatResponse call(Prompt prompt) {
161165

162166
GenerateContentResponse response = this.getContentResponse(geminiRequest);
163167

164-
if (this.isToolFunctionCall(response)) {
165-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(), response);
166-
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
167-
168-
}
169-
170168
List<Generation> generations = response.getCandidatesList()
171169
.stream()
172-
.map(candidate -> candidate.getContent().getPartsList())
170+
.map(this::responseCandiateToGeneration)
173171
.flatMap(List::stream)
174-
.map(Part::getText)
175-
.map(t -> new Generation(t))
176172
.toList();
177173

178-
return new ChatResponse(generations, toChatResponseMetadata(response));
179-
}
174+
ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response));
180175

181-
public List<Message> handleToolCallRequests(List<Message> previousMessages, GenerateContentResponse response) {
182-
183-
Content assistantContent = response.getCandidatesList().get(0).getContent();
184-
185-
List<AssistantMessage.ToolCall> assistantToolCalls = assistantContent.getPartsList()
186-
.stream()
187-
.filter(part -> part.hasFunctionCall())
188-
.map(part -> {
189-
FunctionCall functionCall = part.getFunctionCall();
190-
var functionName = functionCall.getName();
191-
String functionArguments = structToJson(functionCall.getArgs());
192-
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
193-
})
194-
.toList();
195-
196-
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), assistantToolCalls);
197-
198-
ToolResponseMessage toolResponseMessage = this.executeFunctions(assistantMessage);
176+
if (isToolCall(chatResponse, Set.of(FinishReason.STOP.name()))) {
177+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
178+
// Recursively call the call method with the tool call message
179+
// conversation that contains the call responses.
180+
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
181+
}
199182

200-
// History
201-
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
202-
toolCallMessageConversation.add(assistantMessage);
203-
toolCallMessageConversation.add(toolResponseMessage);
204-
return toolCallMessageConversation;
183+
return chatResponse;
205184
}
206185

207186
@Override
@@ -214,33 +193,74 @@ public Flux<ChatResponse> stream(Prompt prompt) {
214193
.generateContentStream(request.contents);
215194

216195
return Flux.fromStream(responseStream.stream()).switchMap(response -> {
217-
if (this.isToolFunctionCall(response)) {
218-
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
219-
response);
196+
197+
List<Generation> generations = response.getCandidatesList()
198+
.stream()
199+
.map(this::responseCandiateToGeneration)
200+
.flatMap(List::stream)
201+
.toList();
202+
203+
ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(response));
204+
205+
if (isToolCall(chatResponse,
206+
Set.of(FinishReason.STOP.name(), FinishReason.FINISH_REASON_UNSPECIFIED.name()))) {
207+
var toolCallConversation = handleToolCalls(prompt, chatResponse);
220208
// Recursively call the stream method with the tool call message
221209
// conversation that contains the call responses.
222-
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
210+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
223211
}
224212

225-
return Mono.just(response).map(response2 -> {
226-
List<Generation> generations = response.getCandidatesList()
227-
.stream()
228-
.map(candidate -> candidate.getContent().getPartsList())
229-
.flatMap(List::stream)
230-
.map(Part::getText)
231-
.map(t -> new Generation(t))
232-
.toList();
233-
234-
return new ChatResponse(generations, toChatResponseMetadata(response));
235-
236-
});
213+
return Flux.just(chatResponse);
237214
});
238215
}
239216
catch (Exception e) {
240217
throw new RuntimeException("Failed to generate content", e);
241218
}
242219
}
243220

221+
protected List<Generation> responseCandiateToGeneration(Candidate candidate) {
222+
223+
// TODO - The candidateIndex (e.g. choice) must be assigned to the generation.
224+
int candidateIndex = candidate.getIndex();
225+
FinishReason candidateFinishReasonn = candidate.getFinishReason();
226+
227+
Map<String, Object> messageMetadata = Map.of("candidateIndex", candidateIndex, "finishReason",
228+
candidateFinishReasonn);
229+
230+
ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata.from(candidateFinishReasonn.name(),
231+
null);
232+
233+
boolean isFunctinCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall);
234+
235+
if (isFunctinCall) {
236+
List<AssistantMessage.ToolCall> assistantToolCalls = candidate.getContent()
237+
.getPartsList()
238+
.stream()
239+
.filter(part -> part.hasFunctionCall())
240+
.map(part -> {
241+
FunctionCall functionCall = part.getFunctionCall();
242+
var functionName = functionCall.getName();
243+
String functionArguments = structToJson(functionCall.getArgs());
244+
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
245+
})
246+
.toList();
247+
248+
AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls);
249+
250+
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
251+
}
252+
else {
253+
List<Generation> generations = candidate.getContent()
254+
.getPartsList()
255+
.stream()
256+
.map(part -> new AssistantMessage(part.getText(), messageMetadata))
257+
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata))
258+
.toList();
259+
260+
return generations;
261+
}
262+
}
263+
244264
private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
245265
return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build();
246266
}
@@ -499,15 +519,6 @@ private GenerateContentResponse getContentResponse(GeminiRequest request) {
499519
}
500520
}
501521

502-
protected boolean isToolFunctionCall(GenerateContentResponse response) {
503-
if (response == null || CollectionUtils.isEmpty(response.getCandidatesList())
504-
|| response.getCandidatesList().get(0).getContent() == null
505-
|| CollectionUtils.isEmpty(response.getCandidatesList().get(0).getContent().getPartsList())) {
506-
return false;
507-
}
508-
return response.getCandidatesList().get(0).getContent().getPartsList().get(0).hasFunctionCall();
509-
}
510-
511522
@Override
512523
public ChatOptions getDefaultOptions() {
513524
return VertexAiGeminiChatOptions.fromOptions(this.defaultOptions);

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
/**
3030
* @author Christian Tzolov
3131
* @author Grogdunn
32+
* @deprecated since 1.0.0-M2 in favor of
33+
* {@link org.springframework.ai.chat.model.AbstractToolCallSupport}
3234
*/
35+
@Deprecated(since = "1.0.0-M2", forRemoval = true)
3336
public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
3437

3538
protected final static boolean IS_RUNTIME_CALL = true;

0 commit comments

Comments
 (0)