Skip to content

Commit 0419888

Browse files
committed
feat: Add return direct support and null safety to ToolCallAdvisor
Implements return direct functionality allowing tools to bypass the LLM and return results directly to clients. Adds null safety checks for chatResponse and comprehensive test coverage. Signed-off-by: Christian Tzolov <[email protected]>
1 parent df6c602 commit 0419888

File tree

3 files changed

+177
-83
lines changed

3 files changed

+177
-83
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisor.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
2626
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
2727
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
28+
import org.springframework.ai.chat.model.ChatResponse;
2829
import org.springframework.ai.chat.prompt.Prompt;
2930
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3031
import org.springframework.ai.model.tool.ToolCallingManager;
@@ -112,14 +113,30 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
112113
// After Call
113114

114115
// TODO: check that this is tool call is sufficiant for all chat models
115-
// that support tool calls.
116-
isToolCall = chatClientResponse.chatResponse().hasToolCalls();
116+
// that support tool calls. (e.g. Anthropic and Bedrock are checking for
117+
// finish status as well)
118+
isToolCall = chatClientResponse.chatResponse() != null && chatClientResponse.chatResponse().hasToolCalls();
117119

118120
if (isToolCall) {
119121

120122
ToolExecutionResult toolExecutionResult = this.toolCallingManager
121123
.executeToolCalls(processedChatClientRequest.prompt(), chatClientResponse.chatResponse());
122124

125+
if (toolExecutionResult.returnDirect()) {
126+
// Interupt the tool calling loop and return the tool execution result
127+
// directly to the client application instead of returning it tothe
128+
// LLM.
129+
isToolCall = false;
130+
131+
// Return tool execution result directly to the application client.
132+
chatClientResponse = chatClientResponse.mutate()
133+
.chatResponse(ChatResponse.builder()
134+
.from(chatClientResponse.chatResponse())
135+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
136+
.build())
137+
.build();
138+
}
139+
123140
instructions = toolExecutionResult.conversationHistory();
124141
}
125142

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/ToolCallAdvisorTests.java

Lines changed: 141 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818

1919
import java.util.List;
2020
import java.util.Map;
21+
import java.util.function.BiFunction;
2122

2223
import io.micrometer.observation.ObservationRegistry;
2324
import org.junit.jupiter.api.Test;
2425
import org.junit.jupiter.api.extension.ExtendWith;
2526
import org.mockito.Mock;
27+
import org.mockito.Mockito;
2628
import org.mockito.junit.jupiter.MockitoExtension;
29+
import org.mockito.quality.Strictness;
2730
import reactor.core.publisher.Flux;
2831

2932
import org.springframework.ai.chat.client.ChatClientRequest;
@@ -36,6 +39,7 @@
3639
import org.springframework.ai.chat.messages.Message;
3740
import org.springframework.ai.chat.messages.ToolResponseMessage;
3841
import org.springframework.ai.chat.messages.UserMessage;
42+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
3943
import org.springframework.ai.chat.model.ChatResponse;
4044
import org.springframework.ai.chat.model.Generation;
4145
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -162,22 +166,28 @@ void testAdviseCallWithoutToolCalls() {
162166
ChatClientResponse response = createMockResponse(false);
163167

164168
// Create a terminal advisor that returns the response
165-
CallAdvisor terminalAdvisor = new CallAdvisor() {
166-
@Override
167-
public String getName() {
168-
return "terminal";
169-
}
169+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> response);
170170

171-
@Override
172-
public int getOrder() {
173-
return 0;
174-
}
171+
// Create a real chain with both advisors
172+
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
173+
.pushAll(List.of(advisor, terminalAdvisor))
174+
.build();
175175

176-
@Override
177-
public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain chain) {
178-
return response;
179-
}
180-
};
176+
ChatClientResponse result = advisor.adviseCall(request, realChain);
177+
178+
assertThat(result).isEqualTo(response);
179+
verify(this.toolCallingManager, times(0)).executeToolCalls(any(), any());
180+
}
181+
182+
@Test
183+
void testAdviseCallWithNullChatResponse() {
184+
ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build();
185+
186+
ChatClientRequest request = createMockRequest(true);
187+
ChatClientResponse responseWithNullChatResponse = ChatClientResponse.builder().build();
188+
189+
// Create a terminal advisor that returns the response with null chatResponse
190+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> responseWithNullChatResponse);
181191

182192
// Create a real chain with both advisors
183193
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
@@ -186,7 +196,7 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
186196

187197
ChatClientResponse result = advisor.adviseCall(request, realChain);
188198

189-
assertThat(result).isEqualTo(response);
199+
assertThat(result).isEqualTo(responseWithNullChatResponse);
190200
verify(this.toolCallingManager, times(0)).executeToolCalls(any(), any());
191201
}
192202

@@ -200,23 +210,11 @@ void testAdviseCallWithSingleToolCallIteration() {
200210

201211
// Create a terminal advisor that returns responses in sequence
202212
int[] callCount = { 0 };
203-
CallAdvisor terminalAdvisor = new CallAdvisor() {
204-
@Override
205-
public String getName() {
206-
return "terminal";
207-
}
208213

209-
@Override
210-
public int getOrder() {
211-
return 0;
212-
}
213-
214-
@Override
215-
public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain chain) {
216-
callCount[0]++;
217-
return callCount[0] == 1 ? responseWithToolCall : finalResponse;
218-
}
219-
};
214+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> {
215+
callCount[0]++;
216+
return callCount[0] == 1 ? responseWithToolCall : finalResponse;
217+
});
220218

221219
// Create a real chain with both advisors
222220
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
@@ -225,7 +223,7 @@ public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain cha
225223

226224
// Mock tool execution result
227225
List<Message> conversationHistory = List.of(new UserMessage("test"),
228-
new AssistantMessage("", Map.of(), List.of()), new ToolResponseMessage(List.of()));
226+
AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build());
229227
ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder()
230228
.conversationHistory(conversationHistory)
231229
.build();
@@ -250,31 +248,18 @@ void testAdviseCallWithMultipleToolCallIterations() {
250248

251249
// Create a terminal advisor that returns responses in sequence
252250
int[] callCount = { 0 };
253-
CallAdvisor terminalAdvisor = new CallAdvisor() {
254-
@Override
255-
public String getName() {
256-
return "terminal";
251+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> {
252+
callCount[0]++;
253+
if (callCount[0] == 1) {
254+
return firstToolCallResponse;
257255
}
258-
259-
@Override
260-
public int getOrder() {
261-
return 0;
256+
else if (callCount[0] == 2) {
257+
return secondToolCallResponse;
262258
}
263-
264-
@Override
265-
public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain chain) {
266-
callCount[0]++;
267-
if (callCount[0] == 1) {
268-
return firstToolCallResponse;
269-
}
270-
else if (callCount[0] == 2) {
271-
return secondToolCallResponse;
272-
}
273-
else {
274-
return finalResponse;
275-
}
259+
else {
260+
return finalResponse;
276261
}
277-
};
262+
});
278263

279264
// Create a real chain with both advisors
280265
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
@@ -284,7 +269,7 @@ else if (callCount[0] == 2) {
284269
// Mock tool execution results
285270
AssistantMessage.builder().build();
286271
List<Message> conversationHistory = List.of(new UserMessage("test"),
287-
new AssistantMessage("", Map.of(), List.of()), new ToolResponseMessage(List.of()));
272+
AssistantMessage.builder().content("").build(), ToolResponseMessage.builder().build());
288273
ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder()
289274
.conversationHistory(conversationHistory)
290275
.build();
@@ -298,6 +283,49 @@ else if (callCount[0] == 2) {
298283
verify(this.toolCallingManager, times(2)).executeToolCalls(any(Prompt.class), any(ChatResponse.class));
299284
}
300285

286+
@Test
287+
void testAdviseCallWithReturnDirectToolExecution() {
288+
ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build();
289+
290+
ChatClientRequest request = createMockRequest(true);
291+
ChatClientResponse responseWithToolCall = createMockResponse(true);
292+
293+
// Create a terminal advisor that returns the response
294+
CallAdvisor terminalAdvisor = new TerminalCallAdvisor((req, chain) -> responseWithToolCall);
295+
296+
// Create a real chain with both advisors
297+
CallAdvisorChain realChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
298+
.pushAll(List.of(advisor, terminalAdvisor))
299+
.build();
300+
301+
// Mock tool execution result with returnDirect = true
302+
ToolResponseMessage.ToolResponse toolResponse = new ToolResponseMessage.ToolResponse("tool-1", "testTool",
303+
"Tool result data");
304+
ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder()
305+
.responses(List.of(toolResponse))
306+
.build();
307+
List<Message> conversationHistory = List.of(new UserMessage("test"),
308+
AssistantMessage.builder().content("").build(), toolResponseMessage);
309+
ToolExecutionResult toolExecutionResult = ToolExecutionResult.builder()
310+
.conversationHistory(conversationHistory)
311+
.returnDirect(true)
312+
.build();
313+
when(this.toolCallingManager.executeToolCalls(any(Prompt.class), any(ChatResponse.class)))
314+
.thenReturn(toolExecutionResult);
315+
316+
ChatClientResponse result = advisor.adviseCall(request, realChain);
317+
318+
// Verify that the tool execution was called only once (no loop continuation)
319+
verify(this.toolCallingManager, times(1)).executeToolCalls(any(Prompt.class), any(ChatResponse.class));
320+
321+
// Verify that the result contains the tool execution result as generations
322+
assertThat(result.chatResponse()).isNotNull();
323+
assertThat(result.chatResponse().getResults()).hasSize(1);
324+
assertThat(result.chatResponse().getResults().get(0).getOutput().getText()).isEqualTo("Tool result data");
325+
assertThat(result.chatResponse().getResults().get(0).getMetadata().getFinishReason())
326+
.isEqualTo(ToolExecutionResult.FINISH_REASON);
327+
}
328+
301329
@Test
302330
void testInternalToolExecutionIsDisabled() {
303331
ToolCallAdvisor advisor = ToolCallAdvisor.builder().toolCallingManager(this.toolCallingManager).build();
@@ -307,23 +335,11 @@ void testInternalToolExecutionIsDisabled() {
307335

308336
// Use a simple holder to capture the request
309337
ChatClientRequest[] capturedRequest = new ChatClientRequest[1];
310-
CallAdvisor capturingAdvisor = new CallAdvisor() {
311-
@Override
312-
public String getName() {
313-
return "capturing";
314-
}
315-
316-
@Override
317-
public int getOrder() {
318-
return 0;
319-
}
320338

321-
@Override
322-
public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain chain) {
323-
capturedRequest[0] = req;
324-
return response;
325-
}
326-
};
339+
CallAdvisor capturingAdvisor = new TerminalCallAdvisor((req, chain) -> {
340+
capturedRequest[0] = req;
341+
return response;
342+
});
327343

328344
CallAdvisorChain capturingChain = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
329345
.pushAll(List.of(advisor, capturingAdvisor))
@@ -369,10 +385,10 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
369385
ChatOptions options = null;
370386
if (withToolCallingOptions) {
371387
ToolCallingChatOptions toolOptions = mock(ToolCallingChatOptions.class,
372-
org.mockito.Mockito.withSettings().lenient());
388+
Mockito.withSettings().strictness(Strictness.LENIENT));
373389
// Create a separate mock for the copy that tracks the internal state
374390
ToolCallingChatOptions copiedOptions = mock(ToolCallingChatOptions.class,
375-
org.mockito.Mockito.withSettings().lenient());
391+
Mockito.withSettings().strictness(Strictness.LENIENT));
376392

377393
// Use a holder to track the state
378394
boolean[] internalToolExecutionEnabled = { true };
@@ -387,7 +403,7 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
387403

388404
// When setInternalToolExecutionEnabled is called on the copy, update the
389405
// state
390-
org.mockito.Mockito.doAnswer(invocation -> {
406+
Mockito.doAnswer(invocation -> {
391407
internalToolExecutionEnabled[0] = invocation.getArgument(0);
392408
return null;
393409
}).when(copiedOptions).setInternalToolExecutionEnabled(org.mockito.ArgumentMatchers.anyBoolean());
@@ -401,17 +417,61 @@ private ChatClientRequest createMockRequest(boolean withToolCallingOptions) {
401417
}
402418

403419
private ChatClientResponse createMockResponse(boolean hasToolCalls) {
404-
ChatResponse chatResponse = mock(ChatResponse.class, org.mockito.Mockito.withSettings().lenient());
405-
when(chatResponse.hasToolCalls()).thenReturn(hasToolCalls);
406-
407-
Generation generation = mock(Generation.class, org.mockito.Mockito.withSettings().lenient());
420+
Generation generation = mock(Generation.class, Mockito.withSettings().strictness(Strictness.LENIENT));
408421
when(generation.getOutput()).thenReturn(new AssistantMessage("response"));
409-
when(chatResponse.getResults()).thenReturn(List.of(generation));
410422

411-
ChatClientResponse response = mock(ChatClientResponse.class, org.mockito.Mockito.withSettings().lenient());
412-
when(response.chatResponse()).thenReturn(chatResponse);
423+
// Mock metadata to avoid NullPointerException in ChatResponse.Builder.from()
424+
ChatResponseMetadata metadata = mock(ChatResponseMetadata.class,
425+
Mockito.withSettings().strictness(Strictness.LENIENT));
426+
when(metadata.getModel()).thenReturn("");
427+
when(metadata.getId()).thenReturn("");
428+
when(metadata.getRateLimit()).thenReturn(null);
429+
when(metadata.getUsage()).thenReturn(null);
430+
when(metadata.getPromptMetadata()).thenReturn(null);
431+
when(metadata.entrySet()).thenReturn(java.util.Collections.emptySet());
432+
433+
// Create a real ChatResponse instead of mocking it to avoid issues with
434+
// ChatResponse.Builder.from()
435+
ChatResponse chatResponse = ChatResponse.builder().generations(List.of(generation)).metadata(metadata).build();
436+
437+
// Mock hasToolCalls since it's not part of the builder
438+
ChatResponse spyChatResponse = Mockito.spy(chatResponse);
439+
when(spyChatResponse.hasToolCalls()).thenReturn(hasToolCalls);
440+
441+
ChatClientResponse response = mock(ChatClientResponse.class,
442+
Mockito.withSettings().strictness(Strictness.LENIENT));
443+
when(response.chatResponse()).thenReturn(spyChatResponse);
444+
445+
// Mock mutate() to return a real builder that can handle the mutation
446+
when(response.mutate())
447+
.thenAnswer(invocation -> ChatClientResponse.builder().chatResponse(spyChatResponse).context(Map.of()));
413448

414449
return response;
415450
}
416451

452+
private static class TerminalCallAdvisor implements CallAdvisor {
453+
454+
private final BiFunction<ChatClientRequest, CallAdvisorChain, ChatClientResponse> responseFunction;
455+
456+
TerminalCallAdvisor(BiFunction<ChatClientRequest, CallAdvisorChain, ChatClientResponse> responseFunction) {
457+
this.responseFunction = responseFunction;
458+
}
459+
460+
@Override
461+
public String getName() {
462+
return "terminal";
463+
}
464+
465+
@Override
466+
public int getOrder() {
467+
return 0;
468+
}
469+
470+
@Override
471+
public ChatClientResponse adviseCall(ChatClientRequest req, CallAdvisorChain chain) {
472+
return this.responseFunction.apply(req, chain);
473+
}
474+
475+
};
476+
417477
}

0 commit comments

Comments
 (0)