1515 */
1616package org .springframework .ai .vertexai .gemini ;
1717
18+ import java .util .ArrayList ;
19+ import java .util .HashSet ;
20+ import java .util .List ;
21+ import java .util .Map ;
22+ import java .util .Set ;
23+ import java .util .stream .Collectors ;
24+
25+ import org .springframework .ai .chat .messages .AssistantMessage ;
26+ import org .springframework .ai .chat .messages .Message ;
27+ import org .springframework .ai .chat .messages .MessageType ;
28+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
29+ import org .springframework .ai .chat .messages .UserMessage ;
30+ import org .springframework .ai .chat .model .ChatModel ;
31+ import org .springframework .ai .chat .model .ChatResponse ;
32+ import org .springframework .ai .chat .model .Generation ;
33+ import org .springframework .ai .chat .prompt .ChatOptions ;
34+ import org .springframework .ai .chat .prompt .Prompt ;
35+ import org .springframework .ai .model .ChatModelDescription ;
36+ import org .springframework .ai .model .ModelOptionsUtils ;
37+ import org .springframework .ai .model .function .AbstractToolCallSupport ;
38+ import org .springframework .ai .model .function .FunctionCallbackContext ;
39+ import org .springframework .ai .vertexai .gemini .metadata .VertexAiChatResponseMetadata ;
40+ import org .springframework .ai .vertexai .gemini .metadata .VertexAiUsage ;
41+ import org .springframework .beans .factory .DisposableBean ;
42+ import org .springframework .lang .NonNull ;
43+ import org .springframework .util .Assert ;
44+ import org .springframework .util .CollectionUtils ;
45+ import org .springframework .util .StringUtils ;
46+
1847import com .fasterxml .jackson .annotation .JsonInclude ;
1948import com .fasterxml .jackson .annotation .JsonInclude .Include ;
2049import com .google .cloud .vertexai .VertexAI ;
2150import com .google .cloud .vertexai .api .Content ;
22- import com .google .cloud .vertexai .api .Content .Builder ;
2351import com .google .cloud .vertexai .api .FunctionCall ;
2452import com .google .cloud .vertexai .api .FunctionDeclaration ;
2553import com .google .cloud .vertexai .api .FunctionResponse ;
3462import com .google .cloud .vertexai .generativeai .ResponseStream ;
3563import com .google .protobuf .Struct ;
3664import com .google .protobuf .util .JsonFormat ;
37- import org .springframework .ai .chat .model .ChatModel ;
38- import org .springframework .ai .chat .model .ChatResponse ;
39- import org .springframework .ai .chat .model .Generation ;
40- import org .springframework .ai .chat .messages .AssistantMessage ;
41- import org .springframework .ai .chat .messages .Message ;
42- import org .springframework .ai .chat .messages .MessageType ;
43- import org .springframework .ai .chat .messages .UserMessage ;
44- import org .springframework .ai .chat .prompt .ChatOptions ;
45- import org .springframework .ai .chat .prompt .Prompt ;
46- import org .springframework .ai .model .ChatModelDescription ;
47- import org .springframework .ai .model .ModelOptionsUtils ;
48- import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
49- import org .springframework .ai .model .function .FunctionCallbackContext ;
50- import org .springframework .ai .vertexai .gemini .metadata .VertexAiChatResponseMetadata ;
51- import org .springframework .ai .vertexai .gemini .metadata .VertexAiUsage ;
52- import org .springframework .beans .factory .DisposableBean ;
53- import org .springframework .lang .NonNull ;
54- import org .springframework .util .Assert ;
55- import org .springframework .util .CollectionUtils ;
56- import org .springframework .util .StringUtils ;
57- import reactor .core .publisher .Flux ;
5865
59- import java .util .ArrayList ;
60- import java .util .HashSet ;
61- import java .util .List ;
62- import java .util .Set ;
63- import java .util .stream .Collectors ;
66+ import reactor .core .publisher .Flux ;
67+ import reactor .core .publisher .Mono ;
6468
6569/**
6670 * @author Christian Tzolov
6771 * @author Grogdunn
6872 * @author luocongqiu
6973 * @since 0.8.1
7074 */
71- public class VertexAiGeminiChatModel
72- extends AbstractFunctionCallSupport <Content , VertexAiGeminiChatModel .GeminiRequest , GenerateContentResponse >
75+ public class VertexAiGeminiChatModel extends AbstractToolCallSupport <GenerateContentResponse >
7376 implements ChatModel , DisposableBean {
7477
7578 private final static boolean IS_RUNTIME_CALL = true ;
@@ -157,7 +160,15 @@ public ChatResponse call(Prompt prompt) {
157160
158161 var geminiRequest = createGeminiRequest (prompt );
159162
160- GenerateContentResponse response = this .callWithFunctionSupport (geminiRequest );
163+ GenerateContentResponse response = this .getContentResponse (geminiRequest );
164+
165+ // GenerateContentResponse response = this.callWithFunctionSupport(geminiRequest);
166+
167+ if (this .isToolFunctionCall (response )) {
168+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (), response );
169+ return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
170+
171+ }
161172
162173 List <Generation > generations = response .getCandidatesList ()
163174 .stream ()
@@ -170,6 +181,32 @@ public ChatResponse call(Prompt prompt) {
170181 return new ChatResponse (generations , toChatResponseMetadata (response ));
171182 }
172183
184+ public List <Message > handleToolCallRequests (List <Message > previousMessages , GenerateContentResponse response ) {
185+
186+ Content assistantContent = response .getCandidatesList ().get (0 ).getContent ();
187+
188+ List <AssistantMessage .ToolCall > assistantToolCalls = assistantContent .getPartsList ()
189+ .stream ()
190+ .filter (part -> part .hasFunctionCall ())
191+ .map (part -> {
192+ FunctionCall functionCall = part .getFunctionCall ();
193+ var functionName = functionCall .getName ();
194+ String functionArguments = structToJson (functionCall .getArgs ());
195+ return new AssistantMessage .ToolCall ("" , "function" , functionName , functionArguments );
196+ })
197+ .toList ();
198+
199+ AssistantMessage assistantMessage = new AssistantMessage ("" , Map .of (), assistantToolCalls );
200+
201+ List <ToolResponseMessage > toolResponseMessages = this .executeFuncitons (assistantMessage , true );
202+
203+ // History
204+ List <Message > toolCallMessageConversation = new ArrayList <>(previousMessages );
205+ toolCallMessageConversation .add (assistantMessage );
206+ toolCallMessageConversation .addAll (toolResponseMessages );
207+ return toolCallMessageConversation ;
208+ }
209+
173210 @ Override
174211 public Flux <ChatResponse > stream (Prompt prompt ) {
175212 try {
@@ -179,9 +216,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
179216 ResponseStream <GenerateContentResponse > responseStream = request .model
180217 .generateContentStream (request .contents );
181218
182- return Flux .fromStream (responseStream .stream ())
183- .switchMap (r -> handleFunctionCallOrReturnStream (request , Flux .just (r )))
184- .map (response -> {
219+ return Flux .fromStream (responseStream .stream ()).switchMap (response -> {
220+ if (this .isToolFunctionCall (response )) {
221+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
222+ response );
223+ // Recursively call the stream method with the tool call message
224+ // conversation that contains the call responses.
225+ return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
226+ }
227+
228+ return Mono .just (response ).map (response2 -> {
185229 List <Generation > generations = response .getCandidatesList ()
186230 .stream ()
187231 .map (candidate -> candidate .getContent ().getPartsList ())
@@ -191,7 +235,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
191235 .toList ();
192236
193237 return new ChatResponse (generations , toChatResponseMetadata (response ));
238+
194239 });
240+ });
195241 }
196242 catch (Exception e ) {
197243 throw new RuntimeException ("Failed to generate content" , e );
@@ -302,7 +348,8 @@ private List<Content> toGeminiContent(Prompt prompt) {
302348
303349 List <Content > contents = prompt .getInstructions ()
304350 .stream ()
305- .filter (m -> m .getMessageType () == MessageType .USER || m .getMessageType () == MessageType .ASSISTANT )
351+ .filter (m -> m .getMessageType () == MessageType .USER || m .getMessageType () == MessageType .ASSISTANT
352+ || m .getMessageType () == MessageType .TOOL )
306353 .map (message -> Content .newBuilder ()
307354 .setRole (toGeminiMessageType (message .getMessageType ()).getValue ())
308355 .addAllParts (messageToGeminiParts (message ))
@@ -318,6 +365,7 @@ private static GeminiMessageType toGeminiMessageType(@NonNull MessageType type)
318365
319366 switch (type ) {
320367 case USER :
368+ case TOOL :
321369 return GeminiMessageType .USER ;
322370 case ASSISTANT :
323371 return GeminiMessageType .MODEL ;
@@ -348,7 +396,34 @@ static List<Part> messageToGeminiParts(Message message) {
348396 return parts ;
349397 }
350398 else if (message instanceof AssistantMessage assistantMessage ) {
351- return List .of (Part .newBuilder ().setText (assistantMessage .getContent ()).build ());
399+ List <Part > parts = new ArrayList <>();
400+ if (StringUtils .hasText (assistantMessage .getContent ())) {
401+ List .of (Part .newBuilder ().setText (assistantMessage .getContent ()).build ());
402+ }
403+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
404+ parts .addAll (assistantMessage .getToolCalls ()
405+ .stream ()
406+ .map (toolCall -> Part .newBuilder ()
407+ .setFunctionCall (FunctionCall .newBuilder ()
408+ .setName (toolCall .name ())
409+ .setArgs (jsonToStruct (toolCall .arguments ()))
410+ .build ())
411+ .build ())
412+ .toList ());
413+ }
414+ return parts ;
415+ }
416+ else if (message instanceof ToolResponseMessage toolResponseMessage ) {
417+
418+ return toolResponseMessage .getResponses ()
419+ .stream ()
420+ .map (response -> Part .newBuilder ()
421+ .setFunctionResponse (FunctionResponse .newBuilder ()
422+ .setName (response .name ())
423+ .setResponse (jsonToStruct (response .respoinse ()))
424+ .build ())
425+ .build ())
426+ .toList ();
352427 }
353428 else {
354429 throw new IllegalArgumentException ("Gemini doesn't support message type: " + message .getClass ());
@@ -402,58 +477,7 @@ private static Schema jsonToSchema(String json) {
402477 }
403478 }
404479
405- @ Override
406- public void destroy () throws Exception {
407- if (this .vertexAI != null ) {
408- this .vertexAI .close ();
409- }
410- }
411-
412- @ Override
413- protected GeminiRequest doCreateToolResponseRequest (GeminiRequest previousRequest , Content responseMessage ,
414- List <Content > conversationHistory ) {
415-
416- var iterator = responseMessage .getPartsList ().iterator ();
417-
418- Builder builder = Content .newBuilder ();
419- while (iterator .hasNext ()) {
420-
421- FunctionCall functionCall = iterator .next ().getFunctionCall ();
422-
423- var functionName = functionCall .getName ();
424- String functionArguments = structToJson (functionCall .getArgs ());
425-
426- if (!this .functionCallbackRegister .containsKey (functionName )) {
427- throw new IllegalStateException ("No function callback found for function name: " + functionName );
428- }
429-
430- String functionResponse = this .functionCallbackRegister .get (functionName ).call (functionArguments );
431-
432- builder .addParts (Part .newBuilder ()
433- .setFunctionResponse (FunctionResponse .newBuilder ()
434- .setName (functionCall .getName ())
435- .setResponse (jsonToStruct (functionResponse ))
436- .build ())
437- .build ());
438-
439- }
440- conversationHistory .add (builder .build ());
441-
442- return new GeminiRequest (conversationHistory , previousRequest .model ());
443- }
444-
445- @ Override
446- protected List <Content > doGetUserMessages (GeminiRequest request ) {
447- return request .contents ;
448- }
449-
450- @ Override
451- protected Content doGetToolResponseMessage (GenerateContentResponse response ) {
452- return response .getCandidatesList ().get (0 ).getContent ();
453- }
454-
455- @ Override
456- protected GenerateContentResponse doChatCompletion (GeminiRequest request ) {
480+ private GenerateContentResponse getContentResponse (GeminiRequest request ) {
457481 try {
458482 return request .model .generateContent (request .contents );
459483 }
@@ -462,19 +486,6 @@ protected GenerateContentResponse doChatCompletion(GeminiRequest request) {
462486 }
463487 }
464488
465- @ Override
466- protected Flux <GenerateContentResponse > doChatCompletionStream (GeminiRequest request ) {
467- try {
468- ResponseStream <GenerateContentResponse > responseStream = request .model
469- .generateContentStream (request .contents );
470-
471- return Flux .fromStream (responseStream .stream ());
472- }
473- catch (Exception e ) {
474- throw new RuntimeException ("Failed to generate content" , e );
475- }
476- }
477-
478489 @ Override
479490 protected boolean isToolFunctionCall (GenerateContentResponse response ) {
480491 if (response == null || CollectionUtils .isEmpty (response .getCandidatesList ())
@@ -490,4 +501,11 @@ public ChatOptions getDefaultOptions() {
490501 return VertexAiGeminiChatOptions .fromOptions (this .defaultOptions );
491502 }
492503
504+ @ Override
505+ public void destroy () throws Exception {
506+ if (this .vertexAI != null ) {
507+ this .vertexAI .close ();
508+ }
509+ }
510+
493511}
0 commit comments