diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java index 091f9e2b8e4..92a038d49d2 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java @@ -31,6 +31,7 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.ai.tool.support.ToolUtils; import org.springframework.ai.util.json.JsonParser; @@ -99,7 +100,15 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { logger.debug("Starting execution of tool: {}", this.toolDefinition.name()); I request = JsonParser.fromJson(toolInput, this.toolInputType); - O response = this.toolFunction.apply(request, toolContext); + O response; + + try { + response = this.toolFunction.apply(request, toolContext); + } + catch (Exception e) { + logger.error("Error executing tool: {}", this.toolDefinition.name(), e); + throw new ToolExecutionException(this.toolDefinition, e); + } logger.debug("Successful execution of tool: {}", this.toolDefinition.name()); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java new file mode 100644 index 00000000000..3db68689e0e --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.function; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.execution.ToolExecutionException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +/** + * Unit tests for {@link FunctionToolCallback}. + * + * @author Marco Schäck + */ +class FunctionToolCallbackTests { + + @Test + void whenToolFunctionExecutesSuccessfullyThenReturnExpectedValue() { + // Given + SquareRootTool squareRootTool = new SquareRootTool(); + FunctionToolCallback callback = FunctionToolCallback + .builder("squareRootTool", squareRootTool::calculate) + .inputType(SquareRootTool.Input.class) + .build(); + + // When + String result = callback.call("{\"number\":25}"); + + // Then + assertThat(result).isEqualTo("{\"result\":5.0}"); + } + + @Test + void whenToolFunctionThrowsExceptionThenWrapInToolExecutionException() { + // Given + SquareRootTool squareRootTool = new SquareRootTool(); + FunctionToolCallback callback = FunctionToolCallback + .builder("squareRootTool", squareRootTool::calculate) + .inputType(SquareRootTool.Input.class) + .build(); + + // When & Then + assertThatThrownBy(() -> callback.call("{\"number\":-9}")).isInstanceOf(ToolExecutionException.class) + .hasCause(new IllegalArgumentException("Cannot calculate square root of negative number: -9")) + .hasMessageContaining("Cannot calculate square root of negative number: -9"); + } + + static class SquareRootTool { + + record Input(int number) { + } + + record Output(double result) { + } + + public Output calculate(Input input) { + if (input.number < 0) { + throw new IllegalArgumentException("Cannot calculate square root of negative number: " + input.number); + } + return new Output(Math.sqrt(input.number)); + } + + } + +}