Skip to content

Commit c6dd6b8

Browse files
committed
Wrap Exception from toolFunction in ToolExecutionException and rethrow
- When `toolFunction` throws a `ToolExecutionException`, rethrow it directly. - When `toolFunction` throws a `Exception`, wrap it in a `ToolExecutionException` and rethrow it. - Add related tests Signed-off-by: YunKui Lu <[email protected]>
1 parent 128c45a commit c6dd6b8

File tree

2 files changed

+185
-1
lines changed

2 files changed

+185
-1
lines changed

spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.tool.definition.ToolDefinition;
3232
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
3333
import org.springframework.ai.tool.execution.ToolCallResultConverter;
34+
import org.springframework.ai.tool.execution.ToolExecutionException;
3435
import org.springframework.ai.tool.metadata.ToolMetadata;
3536
import org.springframework.ai.tool.support.ToolUtils;
3637
import org.springframework.ai.util.json.JsonParser;
@@ -99,13 +100,25 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {
99100
logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
100101

101102
I request = JsonParser.fromJson(toolInput, this.toolInputType);
102-
O response = this.toolFunction.apply(request, toolContext);
103+
O response = callMethod(request, toolContext);
103104

104105
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
105106

106107
return this.toolCallResultConverter.convert(response, null);
107108
}
108109

110+
private O callMethod(I request, @Nullable ToolContext toolContext) {
111+
try {
112+
return this.toolFunction.apply(request, toolContext);
113+
}
114+
catch (ToolExecutionException ex) {
115+
throw ex;
116+
}
117+
catch (Exception ex) {
118+
throw new ToolExecutionException(this.toolDefinition, ex);
119+
}
120+
}
121+
109122
@Override
110123
public String toString() {
111124
return "FunctionToolCallback{" + "toolDefinition=" + this.toolDefinition + ", toolMetadata=" + this.toolMetadata
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package org.springframework.ai.tool.function;
2+
3+
import java.util.Map;
4+
import java.util.concurrent.atomic.AtomicReference;
5+
import java.util.function.BiFunction;
6+
import java.util.function.Consumer;
7+
import java.util.function.Function;
8+
import java.util.function.Supplier;
9+
10+
import org.junit.jupiter.api.Test;
11+
12+
import org.springframework.ai.chat.model.ToolContext;
13+
import org.springframework.ai.tool.execution.ToolExecutionException;
14+
import org.springframework.ai.tool.metadata.ToolMetadata;
15+
16+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
17+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
20+
/**
21+
* @author YunKui Lu
22+
*/
23+
class FunctionToolCallbackTest {
24+
25+
@Test
26+
void testConsumerToolCall() {
27+
TestFunctionTool tool = new TestFunctionTool();
28+
FunctionToolCallback<String, Void> callback = FunctionToolCallback.builder("testTool", tool.stringConsumer())
29+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
30+
.description("test description")
31+
.inputType(String.class)
32+
.build();
33+
34+
callback.call("\"test string param\"");
35+
36+
assertEquals("test string param", tool.calledValue.get());
37+
}
38+
39+
@Test
40+
void testBiFunctionToolCall() {
41+
TestFunctionTool tool = new TestFunctionTool();
42+
FunctionToolCallback<String, String> callback = FunctionToolCallback
43+
.builder("testTool", tool.stringBiFunction())
44+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
45+
.description("test description")
46+
.inputType(String.class)
47+
.build();
48+
49+
ToolContext toolContext = new ToolContext(Map.of("foo", "bar"));
50+
51+
String callResult = callback.call("\"test string param\"", toolContext);
52+
53+
assertEquals("test string param", tool.calledValue.get());
54+
assertEquals("\"return value = test string param\"", callResult);
55+
assertEquals(toolContext, tool.calledToolContext.get());
56+
}
57+
58+
@Test
59+
void testFunctionToolCall() {
60+
TestFunctionTool tool = new TestFunctionTool();
61+
FunctionToolCallback<String, String> callback = FunctionToolCallback.builder("testTool", tool.stringFunction())
62+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
63+
.description("test description")
64+
.inputType(String.class)
65+
.build();
66+
67+
ToolContext toolContext = new ToolContext(Map.of());
68+
69+
String callResult = callback.call("\"test string param\"", toolContext);
70+
71+
assertEquals("test string param", tool.calledValue.get());
72+
assertEquals("\"return value = test string param\"", callResult);
73+
}
74+
75+
@Test
76+
void testSupplierToolCall() {
77+
TestFunctionTool tool = new TestFunctionTool();
78+
79+
FunctionToolCallback<Void, String> callback = FunctionToolCallback.builder("testTool", tool.stringSupplier())
80+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
81+
.description("test description")
82+
.inputType(Void.class)
83+
.build();
84+
85+
ToolContext toolContext = new ToolContext(Map.of());
86+
87+
String callResult = callback.call("\"test string param\"", toolContext);
88+
89+
assertEquals("not params", tool.calledValue.get());
90+
assertEquals("\"return value = \"", callResult);
91+
}
92+
93+
@Test
94+
void testThrowRuntimeException() {
95+
TestFunctionTool tool = new TestFunctionTool();
96+
FunctionToolCallback<String, Void> callback = FunctionToolCallback
97+
.builder("testTool", tool.throwRuntimeException())
98+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
99+
.description("test description")
100+
.inputType(String.class)
101+
.build();
102+
103+
assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception")
104+
.hasCauseInstanceOf(RuntimeException.class)
105+
.asInstanceOf(type(ToolExecutionException.class))
106+
.extracting(ToolExecutionException::getToolDefinition)
107+
.isEqualTo(callback.getToolDefinition());
108+
}
109+
110+
@Test
111+
void testThrowToolExecutionException() {
112+
TestFunctionTool tool = new TestFunctionTool();
113+
FunctionToolCallback<String, Void> callback = FunctionToolCallback
114+
.builder("testTool", tool.throwToolExecutionException())
115+
.toolMetadata(ToolMetadata.builder().returnDirect(true).build())
116+
.description("test description")
117+
.inputType(String.class)
118+
.build();
119+
120+
assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception")
121+
.hasCauseInstanceOf(RuntimeException.class)
122+
.isInstanceOf(ToolExecutionException.class);
123+
}
124+
125+
static class TestFunctionTool {
126+
127+
AtomicReference<Object> calledValue = new AtomicReference<>();
128+
129+
AtomicReference<ToolContext> calledToolContext = new AtomicReference<>();
130+
131+
public Consumer<String> stringConsumer() {
132+
return s -> {
133+
calledValue.set(s);
134+
};
135+
}
136+
137+
public BiFunction<String, ToolContext, String> stringBiFunction() {
138+
return (s, context) -> {
139+
calledValue.set(s);
140+
calledToolContext.set(context);
141+
return "return value = " + s;
142+
};
143+
}
144+
145+
public Function<String, String> stringFunction() {
146+
return s -> {
147+
calledValue.set(s);
148+
return "return value = " + s;
149+
};
150+
}
151+
152+
public Supplier<String> stringSupplier() {
153+
calledValue.set("not params");
154+
return () -> "return value = ";
155+
}
156+
157+
public Consumer<String> throwRuntimeException() {
158+
return s -> {
159+
throw new RuntimeException("test exception");
160+
};
161+
}
162+
163+
public Consumer<String> throwToolExecutionException() {
164+
return s -> {
165+
throw new ToolExecutionException(null, new RuntimeException("test exception"));
166+
};
167+
}
168+
169+
}
170+
171+
}

0 commit comments

Comments
 (0)