Skip to content

Commit 855b079

Browse files
committed
fix: handle errors when toolExecution fails for functionToolCallback
Signed-off-by: Marco Schaeck <[email protected]>
1 parent 59f2b3b commit 855b079

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.tool.definition.ToolDefinition;
3232
import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
3333
import org.springframework.ai.tool.execution.ToolCallResultConverter;
34+
import org.springframework.ai.tool.execution.ToolExecutionException;
3435
import org.springframework.ai.tool.metadata.ToolMetadata;
3536
import org.springframework.ai.tool.support.ToolUtils;
3637
import org.springframework.ai.util.json.JsonParser;
@@ -99,7 +100,15 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {
99100
logger.debug("Starting execution of tool: {}", this.toolDefinition.name());
100101

101102
I request = JsonParser.fromJson(toolInput, this.toolInputType);
102-
O response = this.toolFunction.apply(request, toolContext);
103+
O response;
104+
105+
try {
106+
response = this.toolFunction.apply(request, toolContext);
107+
}
108+
catch (Exception e) {
109+
logger.error("Error executing tool: {}", this.toolDefinition.name(), e);
110+
throw new ToolExecutionException(this.toolDefinition, e);
111+
}
103112

104113
logger.debug("Successful execution of tool: {}", this.toolDefinition.name());
105114

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.tool.function;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.springframework.ai.tool.execution.ToolExecutionException;
21+
22+
import static org.assertj.core.api.Assertions.assertThat;
23+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
24+
25+
/**
26+
* Unit tests for {@link FunctionToolCallback}.
27+
*
28+
* @author Marco Schäck
29+
*/
30+
class FunctionToolCallbackTests {
31+
32+
@Test
33+
void whenToolFunctionExecutesSuccessfullyThenReturnExpectedValue() {
34+
// Given
35+
SquareRootTool squareRootTool = new SquareRootTool();
36+
FunctionToolCallback<SquareRootTool.Input, SquareRootTool.Output> callback = FunctionToolCallback
37+
.builder("squareRootTool", squareRootTool::calculate)
38+
.inputType(SquareRootTool.Input.class)
39+
.build();
40+
41+
// When
42+
String result = callback.call("{\"number\":25}");
43+
44+
// Then
45+
assertThat(result).isEqualTo("{\"result\":5.0}");
46+
}
47+
48+
@Test
49+
void whenToolFunctionThrowsExceptionThenWrapInToolExecutionException() {
50+
// Given
51+
SquareRootTool squareRootTool = new SquareRootTool();
52+
FunctionToolCallback<SquareRootTool.Input, SquareRootTool.Output> callback = FunctionToolCallback
53+
.builder("squareRootTool", squareRootTool::calculate)
54+
.inputType(SquareRootTool.Input.class)
55+
.build();
56+
57+
// When & Then
58+
assertThatThrownBy(() -> callback.call("{\"number\":-9}")).isInstanceOf(ToolExecutionException.class)
59+
.hasCause(new IllegalArgumentException("Cannot calculate square root of negative number: -9"))
60+
.hasMessageContaining("Cannot calculate square root of negative number: -9");
61+
}
62+
63+
static class SquareRootTool {
64+
65+
record Input(int number) {
66+
}
67+
68+
record Output(double result) {
69+
}
70+
71+
public Output calculate(Input input) {
72+
if (input.number < 0) {
73+
throw new IllegalArgumentException("Cannot calculate square root of negative number: " + input.number);
74+
}
75+
return new Output(Math.sqrt(input.number));
76+
}
77+
78+
}
79+
80+
}

0 commit comments

Comments
 (0)