1818
1919import java .util .List ;
2020import java .util .Map ;
21+ import java .util .function .BiFunction ;
2122
2223import io .micrometer .observation .ObservationRegistry ;
2324import org .junit .jupiter .api .Test ;
2425import org .junit .jupiter .api .extension .ExtendWith ;
2526import org .mockito .Mock ;
27+ import org .mockito .Mockito ;
2628import org .mockito .junit .jupiter .MockitoExtension ;
29+ import org .mockito .quality .Strictness ;
2730import reactor .core .publisher .Flux ;
2831
2932import org .springframework .ai .chat .client .ChatClientRequest ;
3639import org .springframework .ai .chat .messages .Message ;
3740import org .springframework .ai .chat .messages .ToolResponseMessage ;
3841import org .springframework .ai .chat .messages .UserMessage ;
42+ import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
3943import org .springframework .ai .chat .model .ChatResponse ;
4044import org .springframework .ai .chat .model .Generation ;
4145import org .springframework .ai .chat .prompt .ChatOptions ;
@@ -162,22 +166,28 @@ void testAdviseCallWithoutToolCalls() {
162166 ChatClientResponse response = createMockResponse (false );
163167
164168 // Create a terminal advisor that returns the response
165- CallAdvisor terminalAdvisor = new CallAdvisor () {
166- @ Override
167- public String getName () {
168- return "terminal" ;
169- }
169+ CallAdvisor terminalAdvisor = new TerminalCallAdvisor ((req , chain ) -> response );
170170
171- @ Override
172- public int getOrder () {
173- return 0 ;
174- }
171+ // Create a real chain with both advisors
172+ CallAdvisorChain realChain = DefaultAroundAdvisorChain . builder ( ObservationRegistry . NOOP )
173+ . pushAll ( List . of ( advisor , terminalAdvisor ))
174+ . build ();
175175
176- @ Override
177- public ChatClientResponse adviseCall (ChatClientRequest req , CallAdvisorChain chain ) {
178- return response ;
179- }
180- };
176+ ChatClientResponse result = advisor .adviseCall (request , realChain );
177+
178+ assertThat (result ).isEqualTo (response );
179+ verify (this .toolCallingManager , times (0 )).executeToolCalls (any (), any ());
180+ }
181+
182+ @ Test
183+ void testAdviseCallWithNullChatResponse () {
184+ ToolCallAdvisor advisor = ToolCallAdvisor .builder ().toolCallingManager (this .toolCallingManager ).build ();
185+
186+ ChatClientRequest request = createMockRequest (true );
187+ ChatClientResponse responseWithNullChatResponse = ChatClientResponse .builder ().build ();
188+
189+ // Create a terminal advisor that returns the response with null chatResponse
190+ CallAdvisor terminalAdvisor = new TerminalCallAdvisor ((req , chain ) -> responseWithNullChatResponse );
181191
182192 // Create a real chain with both advisors
183193 CallAdvisorChain realChain = DefaultAroundAdvisorChain .builder (ObservationRegistry .NOOP )
@@ -186,7 +196,7 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
186196
187197 ChatClientResponse result = advisor .adviseCall (request , realChain );
188198
189- assertThat (result ).isEqualTo (response );
199+ assertThat (result ).isEqualTo (responseWithNullChatResponse );
190200 verify (this .toolCallingManager , times (0 )).executeToolCalls (any (), any ());
191201 }
192202
@@ -200,23 +210,11 @@ void testAdviseCallWithSingleToolCallIteration() {
200210
201211 // Create a terminal advisor that returns responses in sequence
202212 int [] callCount = { 0 };
203- CallAdvisor terminalAdvisor = new CallAdvisor () {
204- @ Override
205- public String getName () {
206- return "terminal" ;
207- }
208213
209- @ Override
210- public int getOrder () {
211- return 0 ;
212- }
213-
214- @ Override
215- public ChatClientResponse adviseCall (ChatClientRequest req , CallAdvisorChain chain ) {
216- callCount [0 ]++;
217- return callCount [0 ] == 1 ? responseWithToolCall : finalResponse ;
218- }
219- };
214+ CallAdvisor terminalAdvisor = new TerminalCallAdvisor ((req , chain ) -> {
215+ callCount [0 ]++;
216+ return callCount [0 ] == 1 ? responseWithToolCall : finalResponse ;
217+ });
220218
221219 // Create a real chain with both advisors
222220 CallAdvisorChain realChain = DefaultAroundAdvisorChain .builder (ObservationRegistry .NOOP )
@@ -225,7 +223,7 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
225223
226224 // Mock tool execution result
227225 List <Message > conversationHistory = List .of (new UserMessage ("test" ),
228- new AssistantMessage ( "" , Map . of (), List . of ()), new ToolResponseMessage ( List . of () ));
226+ AssistantMessage . builder (). content ( "" ). build (), ToolResponseMessage . builder (). build ( ));
229227 ToolExecutionResult toolExecutionResult = ToolExecutionResult .builder ()
230228 .conversationHistory (conversationHistory )
231229 .build ();
@@ -250,31 +248,18 @@ void testAdviseCallWithMultipleToolCallIterations() {
250248
251249 // Create a terminal advisor that returns responses in sequence
252250 int [] callCount = { 0 };
253- CallAdvisor terminalAdvisor = new CallAdvisor () {
254- @ Override
255- public String getName ( ) {
256- return "terminal" ;
251+ CallAdvisor terminalAdvisor = new TerminalCallAdvisor (( req , chain ) -> {
252+ callCount [ 0 ]++;
253+ if ( callCount [ 0 ] == 1 ) {
254+ return firstToolCallResponse ;
257255 }
258-
259- @ Override
260- public int getOrder () {
261- return 0 ;
256+ else if (callCount [0 ] == 2 ) {
257+ return secondToolCallResponse ;
262258 }
263-
264- @ Override
265- public ChatClientResponse adviseCall (ChatClientRequest req , CallAdvisorChain chain ) {
266- callCount [0 ]++;
267- if (callCount [0 ] == 1 ) {
268- return firstToolCallResponse ;
269- }
270- else if (callCount [0 ] == 2 ) {
271- return secondToolCallResponse ;
272- }
273- else {
274- return finalResponse ;
275- }
259+ else {
260+ return finalResponse ;
276261 }
277- };
262+ }) ;
278263
279264 // Create a real chain with both advisors
280265 CallAdvisorChain realChain = DefaultAroundAdvisorChain .builder (ObservationRegistry .NOOP )
@@ -284,7 +269,7 @@ else if (callCount[0] == 2) {
284269 // Mock tool execution results
285270 AssistantMessage .builder ().build ();
286271 List <Message > conversationHistory = List .of (new UserMessage ("test" ),
287- new AssistantMessage ( "" , Map . of (), List . of ()), new ToolResponseMessage ( List . of () ));
272+ AssistantMessage . builder (). content ( "" ). build (), ToolResponseMessage . builder (). build ( ));
288273 ToolExecutionResult toolExecutionResult = ToolExecutionResult .builder ()
289274 .conversationHistory (conversationHistory )
290275 .build ();
@@ -298,6 +283,49 @@ else if (callCount[0] == 2) {
298283 verify (this .toolCallingManager , times (2 )).executeToolCalls (any (Prompt .class ), any (ChatResponse .class ));
299284 }
300285
286+ @ Test
287+ void testAdviseCallWithReturnDirectToolExecution () {
288+ ToolCallAdvisor advisor = ToolCallAdvisor .builder ().toolCallingManager (this .toolCallingManager ).build ();
289+
290+ ChatClientRequest request = createMockRequest (true );
291+ ChatClientResponse responseWithToolCall = createMockResponse (true );
292+
293+ // Create a terminal advisor that returns the response
294+ CallAdvisor terminalAdvisor = new TerminalCallAdvisor ((req , chain ) -> responseWithToolCall );
295+
296+ // Create a real chain with both advisors
297+ CallAdvisorChain realChain = DefaultAroundAdvisorChain .builder (ObservationRegistry .NOOP )
298+ .pushAll (List .of (advisor , terminalAdvisor ))
299+ .build ();
300+
301+ // Mock tool execution result with returnDirect = true
302+ ToolResponseMessage .ToolResponse toolResponse = new ToolResponseMessage .ToolResponse ("tool-1" , "testTool" ,
303+ "Tool result data" );
304+ ToolResponseMessage toolResponseMessage = ToolResponseMessage .builder ()
305+ .responses (List .of (toolResponse ))
306+ .build ();
307+ List <Message > conversationHistory = List .of (new UserMessage ("test" ),
308+ AssistantMessage .builder ().content ("" ).build (), toolResponseMessage );
309+ ToolExecutionResult toolExecutionResult = ToolExecutionResult .builder ()
310+ .conversationHistory (conversationHistory )
311+ .returnDirect (true )
312+ .build ();
313+ when (this .toolCallingManager .executeToolCalls (any (Prompt .class ), any (ChatResponse .class )))
314+ .thenReturn (toolExecutionResult );
315+
316+ ChatClientResponse result = advisor .adviseCall (request , realChain );
317+
318+ // Verify that the tool execution was called only once (no loop continuation)
319+ verify (this .toolCallingManager , times (1 )).executeToolCalls (any (Prompt .class ), any (ChatResponse .class ));
320+
321+ // Verify that the result contains the tool execution result as generations
322+ assertThat (result .chatResponse ()).isNotNull ();
323+ assertThat (result .chatResponse ().getResults ()).hasSize (1 );
324+ assertThat (result .chatResponse ().getResults ().get (0 ).getOutput ().getText ()).isEqualTo ("Tool result data" );
325+ assertThat (result .chatResponse ().getResults ().get (0 ).getMetadata ().getFinishReason ())
326+ .isEqualTo (ToolExecutionResult .FINISH_REASON );
327+ }
328+
301329 @ Test
302330 void testInternalToolExecutionIsDisabled () {
303331 ToolCallAdvisor advisor = ToolCallAdvisor .builder ().toolCallingManager (this .toolCallingManager ).build ();
@@ -307,23 +335,11 @@ void testInternalToolExecutionIsDisabled() {
307335
308336 // Use a simple holder to capture the request
309337 ChatClientRequest [] capturedRequest = new ChatClientRequest [1 ];
310- CallAdvisor capturingAdvisor = new CallAdvisor () {
311- @ Override
312- public String getName () {
313- return "capturing" ;
314- }
315-
316- @ Override
317- public int getOrder () {
318- return 0 ;
319- }
320338
321- @ Override
322- public ChatClientResponse adviseCall (ChatClientRequest req , CallAdvisorChain chain ) {
323- capturedRequest [0 ] = req ;
324- return response ;
325- }
326- };
339+ CallAdvisor capturingAdvisor = new TerminalCallAdvisor ((req , chain ) -> {
340+ capturedRequest [0 ] = req ;
341+ return response ;
342+ });
327343
328344 CallAdvisorChain capturingChain = DefaultAroundAdvisorChain .builder (ObservationRegistry .NOOP )
329345 .pushAll (List .of (advisor , capturingAdvisor ))
@@ -369,10 +385,10 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
369385 ChatOptions options = null ;
370386 if (withToolCallingOptions ) {
371387 ToolCallingChatOptions toolOptions = mock (ToolCallingChatOptions .class ,
372- org . mockito . Mockito .withSettings ().lenient ( ));
388+ Mockito .withSettings ().strictness ( Strictness . LENIENT ));
373389 // Create a separate mock for the copy that tracks the internal state
374390 ToolCallingChatOptions copiedOptions = mock (ToolCallingChatOptions .class ,
375- org . mockito . Mockito .withSettings ().lenient ( ));
391+ Mockito .withSettings ().strictness ( Strictness . LENIENT ));
376392
377393 // Use a holder to track the state
378394 boolean [] internalToolExecutionEnabled = { true };
@@ -387,7 +403,7 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
387403
388404 // When setInternalToolExecutionEnabled is called on the copy, update the
389405 // state
390- org . mockito . Mockito .doAnswer (invocation -> {
406+ Mockito .doAnswer (invocation -> {
391407 internalToolExecutionEnabled [0 ] = invocation .getArgument (0 );
392408 return null ;
393409 }).when (copiedOptions ).setInternalToolExecutionEnabled (org .mockito .ArgumentMatchers .anyBoolean ());
@@ -401,17 +417,61 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
401417 }
402418
403419 private ChatClientResponse createMockResponse (boolean hasToolCalls ) {
404- ChatResponse chatResponse = mock (ChatResponse .class , org .mockito .Mockito .withSettings ().lenient ());
405- when (chatResponse .hasToolCalls ()).thenReturn (hasToolCalls );
406-
407- Generation generation = mock (Generation .class , org .mockito .Mockito .withSettings ().lenient ());
420+ Generation generation = mock (Generation .class , Mockito .withSettings ().strictness (Strictness .LENIENT ));
408421 when (generation .getOutput ()).thenReturn (new AssistantMessage ("response" ));
409- when (chatResponse .getResults ()).thenReturn (List .of (generation ));
410422
411- ChatClientResponse response = mock (ChatClientResponse .class , org .mockito .Mockito .withSettings ().lenient ());
412- when (response .chatResponse ()).thenReturn (chatResponse );
423+ // Mock metadata to avoid NullPointerException in ChatResponse.Builder.from()
424+ ChatResponseMetadata metadata = mock (ChatResponseMetadata .class ,
425+ Mockito .withSettings ().strictness (Strictness .LENIENT ));
426+ when (metadata .getModel ()).thenReturn ("" );
427+ when (metadata .getId ()).thenReturn ("" );
428+ when (metadata .getRateLimit ()).thenReturn (null );
429+ when (metadata .getUsage ()).thenReturn (null );
430+ when (metadata .getPromptMetadata ()).thenReturn (null );
431+ when (metadata .entrySet ()).thenReturn (java .util .Collections .emptySet ());
432+
433+ // Create a real ChatResponse instead of mocking it to avoid issues with
434+ // ChatResponse.Builder.from()
435+ ChatResponse chatResponse = ChatResponse .builder ().generations (List .of (generation )).metadata (metadata ).build ();
436+
437+ // Mock hasToolCalls since it's not part of the builder
438+ ChatResponse spyChatResponse = Mockito .spy (chatResponse );
439+ when (spyChatResponse .hasToolCalls ()).thenReturn (hasToolCalls );
440+
441+ ChatClientResponse response = mock (ChatClientResponse .class ,
442+ Mockito .withSettings ().strictness (Strictness .LENIENT ));
443+ when (response .chatResponse ()).thenReturn (spyChatResponse );
444+
445+ // Mock mutate() to return a real builder that can handle the mutation
446+ when (response .mutate ())
447+ .thenAnswer (invocation -> ChatClientResponse .builder ().chatResponse (spyChatResponse ).context (Map .of ()));
413448
414449 return response ;
415450 }
416451
452+ private static class TerminalCallAdvisor implements CallAdvisor {
453+
454+ private final BiFunction <ChatClientRequest , CallAdvisorChain , ChatClientResponse > responseFunction ;
455+
456+ TerminalCallAdvisor (BiFunction <ChatClientRequest , CallAdvisorChain , ChatClientResponse > responseFunction ) {
457+ this .responseFunction = responseFunction ;
458+ }
459+
460+ @ Override
461+ public String getName () {
462+ return "terminal" ;
463+ }
464+
465+ @ Override
466+ public int getOrder () {
467+ return 0 ;
468+ }
469+
470+ @ Override
471+ public ChatClientResponse adviseCall (ChatClientRequest req , CallAdvisorChain chain ) {
472+ return this .responseFunction .apply (req , chain );
473+ }
474+
475+ };
476+
417477}
0 commit comments