15
15
*/
16
16
package org .springframework .ai .vertexai .gemini ;
17
17
18
- import com .fasterxml .jackson .annotation .JsonInclude ;
19
- import com .fasterxml .jackson .annotation .JsonInclude .Include ;
20
- import com .google .cloud .vertexai .VertexAI ;
21
- import com .google .cloud .vertexai .api .Content ;
22
- import com .google .cloud .vertexai .api .FunctionCall ;
23
- import com .google .cloud .vertexai .api .FunctionDeclaration ;
24
- import com .google .cloud .vertexai .api .FunctionResponse ;
25
- import com .google .cloud .vertexai .api .GenerateContentResponse ;
26
- import com .google .cloud .vertexai .api .GenerationConfig ;
27
- import com .google .cloud .vertexai .api .Part ;
28
- import com .google .cloud .vertexai .api .Schema ;
29
- import com .google .cloud .vertexai .api .Tool ;
30
- import com .google .cloud .vertexai .generativeai .GenerativeModel ;
31
- import com .google .cloud .vertexai .generativeai .PartMaker ;
32
- import com .google .cloud .vertexai .generativeai .ResponseStream ;
33
- import com .google .protobuf .Struct ;
34
- import com .google .protobuf .util .JsonFormat ;
18
+ import java .util .ArrayList ;
19
+ import java .util .Collection ;
20
+ import java .util .HashSet ;
21
+ import java .util .List ;
22
+ import java .util .Map ;
23
+ import java .util .Set ;
24
+
35
25
import org .springframework .ai .chat .messages .AssistantMessage ;
36
- import org .springframework .ai .model .Media ;
37
26
import org .springframework .ai .chat .messages .Message ;
38
27
import org .springframework .ai .chat .messages .MessageType ;
39
28
import org .springframework .ai .chat .messages .SystemMessage ;
40
29
import org .springframework .ai .chat .messages .ToolResponseMessage ;
41
30
import org .springframework .ai .chat .messages .UserMessage ;
31
+ import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
42
32
import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
33
+ import org .springframework .ai .chat .model .AbstractToolCallSupport ;
43
34
import org .springframework .ai .chat .model .ChatModel ;
44
35
import org .springframework .ai .chat .model .ChatResponse ;
45
36
import org .springframework .ai .chat .model .Generation ;
46
37
import org .springframework .ai .chat .prompt .ChatOptions ;
47
38
import org .springframework .ai .chat .prompt .Prompt ;
48
39
import org .springframework .ai .model .ChatModelDescription ;
40
+ import org .springframework .ai .model .Media ;
49
41
import org .springframework .ai .model .ModelOptionsUtils ;
50
- import org .springframework .ai .chat .model .AbstractToolCallSupport ;
51
42
import org .springframework .ai .model .function .FunctionCallbackContext ;
52
43
import org .springframework .ai .vertexai .gemini .metadata .VertexAiUsage ;
53
44
import org .springframework .beans .factory .DisposableBean ;
54
45
import org .springframework .lang .NonNull ;
55
46
import org .springframework .util .Assert ;
56
47
import org .springframework .util .CollectionUtils ;
57
48
import org .springframework .util .StringUtils ;
58
- import reactor .core .publisher .Flux ;
59
- import reactor .core .publisher .Mono ;
60
49
61
- import java .util .ArrayList ;
62
- import java .util .Collection ;
63
- import java .util .HashSet ;
64
- import java .util .List ;
65
- import java .util .Map ;
66
- import java .util .Set ;
50
+ import com .fasterxml .jackson .annotation .JsonInclude ;
51
+ import com .fasterxml .jackson .annotation .JsonInclude .Include ;
52
+ import com .google .cloud .vertexai .VertexAI ;
53
+ import com .google .cloud .vertexai .api .Candidate ;
54
+ import com .google .cloud .vertexai .api .Candidate .FinishReason ;
55
+ import com .google .cloud .vertexai .api .Content ;
56
+ import com .google .cloud .vertexai .api .FunctionCall ;
57
+ import com .google .cloud .vertexai .api .FunctionDeclaration ;
58
+ import com .google .cloud .vertexai .api .FunctionResponse ;
59
+ import com .google .cloud .vertexai .api .GenerateContentResponse ;
60
+ import com .google .cloud .vertexai .api .GenerationConfig ;
61
+ import com .google .cloud .vertexai .api .Part ;
62
+ import com .google .cloud .vertexai .api .Schema ;
63
+ import com .google .cloud .vertexai .api .Tool ;
64
+ import com .google .cloud .vertexai .generativeai .GenerativeModel ;
65
+ import com .google .cloud .vertexai .generativeai .PartMaker ;
66
+ import com .google .cloud .vertexai .generativeai .ResponseStream ;
67
+ import com .google .protobuf .Struct ;
68
+ import com .google .protobuf .util .JsonFormat ;
69
+
70
+ import reactor .core .publisher .Flux ;
67
71
68
72
/**
69
73
* @author Christian Tzolov
@@ -161,47 +165,22 @@ public ChatResponse call(Prompt prompt) {
161
165
162
166
GenerateContentResponse response = this .getContentResponse (geminiRequest );
163
167
164
- if (this .isToolFunctionCall (response )) {
165
- List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (), response );
166
- return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
167
-
168
- }
169
-
170
168
List <Generation > generations = response .getCandidatesList ()
171
169
.stream ()
172
- .map (candidate -> candidate . getContent (). getPartsList () )
170
+ .map (this :: responseCandiateToGeneration )
173
171
.flatMap (List ::stream )
174
- .map (Part ::getText )
175
- .map (t -> new Generation (t ))
176
172
.toList ();
177
173
178
- return new ChatResponse (generations , toChatResponseMetadata (response ));
179
- }
174
+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (response ));
180
175
181
- public List <Message > handleToolCallRequests (List <Message > previousMessages , GenerateContentResponse response ) {
182
-
183
- Content assistantContent = response .getCandidatesList ().get (0 ).getContent ();
184
-
185
- List <AssistantMessage .ToolCall > assistantToolCalls = assistantContent .getPartsList ()
186
- .stream ()
187
- .filter (part -> part .hasFunctionCall ())
188
- .map (part -> {
189
- FunctionCall functionCall = part .getFunctionCall ();
190
- var functionName = functionCall .getName ();
191
- String functionArguments = structToJson (functionCall .getArgs ());
192
- return new AssistantMessage .ToolCall ("" , "function" , functionName , functionArguments );
193
- })
194
- .toList ();
195
-
196
- AssistantMessage assistantMessage = new AssistantMessage ("" , Map .of (), assistantToolCalls );
197
-
198
- ToolResponseMessage toolResponseMessage = this .executeFunctions (assistantMessage );
176
+ if (isToolCall (chatResponse , Set .of (FinishReason .STOP .name ()))) {
177
+ var toolCallConversation = handleToolCalls (prompt , chatResponse );
178
+ // Recursively call the call method with the tool call message
179
+ // conversation that contains the call responses.
180
+ return this .call (new Prompt (toolCallConversation , prompt .getOptions ()));
181
+ }
199
182
200
- // History
201
- List <Message > toolCallMessageConversation = new ArrayList <>(previousMessages );
202
- toolCallMessageConversation .add (assistantMessage );
203
- toolCallMessageConversation .add (toolResponseMessage );
204
- return toolCallMessageConversation ;
183
+ return chatResponse ;
205
184
}
206
185
207
186
@ Override
@@ -214,33 +193,74 @@ public Flux<ChatResponse> stream(Prompt prompt) {
214
193
.generateContentStream (request .contents );
215
194
216
195
return Flux .fromStream (responseStream .stream ()).switchMap (response -> {
217
- if (this .isToolFunctionCall (response )) {
218
- List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
219
- response );
196
+
197
+ List <Generation > generations = response .getCandidatesList ()
198
+ .stream ()
199
+ .map (this ::responseCandiateToGeneration )
200
+ .flatMap (List ::stream )
201
+ .toList ();
202
+
203
+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (response ));
204
+
205
+ if (isToolCall (chatResponse ,
206
+ Set .of (FinishReason .STOP .name (), FinishReason .FINISH_REASON_UNSPECIFIED .name ()))) {
207
+ var toolCallConversation = handleToolCalls (prompt , chatResponse );
220
208
// Recursively call the stream method with the tool call message
221
209
// conversation that contains the call responses.
222
- return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
210
+ return this .stream (new Prompt (toolCallConversation , prompt .getOptions ()));
223
211
}
224
212
225
- return Mono .just (response ).map (response2 -> {
226
- List <Generation > generations = response .getCandidatesList ()
227
- .stream ()
228
- .map (candidate -> candidate .getContent ().getPartsList ())
229
- .flatMap (List ::stream )
230
- .map (Part ::getText )
231
- .map (t -> new Generation (t ))
232
- .toList ();
233
-
234
- return new ChatResponse (generations , toChatResponseMetadata (response ));
235
-
236
- });
213
+ return Flux .just (chatResponse );
237
214
});
238
215
}
239
216
catch (Exception e ) {
240
217
throw new RuntimeException ("Failed to generate content" , e );
241
218
}
242
219
}
243
220
221
+ protected List <Generation > responseCandiateToGeneration (Candidate candidate ) {
222
+
223
+ // TODO - The candidateIndex (e.g. choice must be asigned to the generation).
224
+ int candidateIndex = candidate .getIndex ();
225
+ FinishReason candidateFinishReasonn = candidate .getFinishReason ();
226
+
227
+ Map <String , Object > messageMetadata = Map .of ("candidateIndex" , candidateIndex , "finishReason" ,
228
+ candidateFinishReasonn );
229
+
230
+ ChatGenerationMetadata chatGenerationMetadata = ChatGenerationMetadata .from (candidateFinishReasonn .name (),
231
+ null );
232
+
233
+ boolean isFunctinCall = candidate .getContent ().getPartsList ().stream ().allMatch (Part ::hasFunctionCall );
234
+
235
+ if (isFunctinCall ) {
236
+ List <AssistantMessage .ToolCall > assistantToolCalls = candidate .getContent ()
237
+ .getPartsList ()
238
+ .stream ()
239
+ .filter (part -> part .hasFunctionCall ())
240
+ .map (part -> {
241
+ FunctionCall functionCall = part .getFunctionCall ();
242
+ var functionName = functionCall .getName ();
243
+ String functionArguments = structToJson (functionCall .getArgs ());
244
+ return new AssistantMessage .ToolCall ("" , "function" , functionName , functionArguments );
245
+ })
246
+ .toList ();
247
+
248
+ AssistantMessage assistantMessage = new AssistantMessage ("" , messageMetadata , assistantToolCalls );
249
+
250
+ return List .of (new Generation (assistantMessage , chatGenerationMetadata ));
251
+ }
252
+ else {
253
+ List <Generation > generations = candidate .getContent ()
254
+ .getPartsList ()
255
+ .stream ()
256
+ .map (part -> new AssistantMessage (part .getText (), messageMetadata ))
257
+ .map (assistantMessage -> new Generation (assistantMessage , chatGenerationMetadata ))
258
+ .toList ();
259
+
260
+ return generations ;
261
+ }
262
+ }
263
+
244
264
private ChatResponseMetadata toChatResponseMetadata (GenerateContentResponse response ) {
245
265
return ChatResponseMetadata .builder ().withUsage (new VertexAiUsage (response .getUsageMetadata ())).build ();
246
266
}
@@ -499,15 +519,6 @@ private GenerateContentResponse getContentResponse(GeminiRequest request) {
499
519
}
500
520
}
501
521
502
- protected boolean isToolFunctionCall (GenerateContentResponse response ) {
503
- if (response == null || CollectionUtils .isEmpty (response .getCandidatesList ())
504
- || response .getCandidatesList ().get (0 ).getContent () == null
505
- || CollectionUtils .isEmpty (response .getCandidatesList ().get (0 ).getContent ().getPartsList ())) {
506
- return false ;
507
- }
508
- return response .getCandidatesList ().get (0 ).getContent ().getPartsList ().get (0 ).hasFunctionCall ();
509
- }
510
-
511
522
@ Override
512
523
public ChatOptions getDefaultOptions () {
513
524
return VertexAiGeminiChatOptions .fromOptions (this .defaultOptions );
0 commit comments