diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java index 091f9e2b8e4..4892a803bfe 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java @@ -31,6 +31,7 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.ai.tool.support.ToolUtils; import org.springframework.ai.util.json.JsonParser; @@ -44,6 +45,7 @@ * A {@link ToolCallback} implementation to invoke functions as tools. * * @author Thomas Vitale + * @author YunKui Lu * @since 1.0.0 */ public class FunctionToolCallback implements ToolCallback { @@ -99,13 +101,25 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { logger.debug("Starting execution of tool: {}", this.toolDefinition.name()); I request = JsonParser.fromJson(toolInput, this.toolInputType); - O response = this.toolFunction.apply(request, toolContext); + O response = callMethod(request, toolContext); logger.debug("Successful execution of tool: {}", this.toolDefinition.name()); return this.toolCallResultConverter.convert(response, null); } + private O callMethod(I request, @Nullable ToolContext toolContext) { + try { + return this.toolFunction.apply(request, toolContext); + } + catch (ToolExecutionException ex) { + throw ex; + } + catch (Exception ex) { + throw new ToolExecutionException(this.toolDefinition, ex); + } + } + @Override public String toString() { return "FunctionToolCallback{" + "toolDefinition=" + this.toolDefinition + ", toolMetadata=" + this.toolMetadata diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java new file mode 100644 index 00000000000..4d6190305d1 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.function; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author YunKui Lu + */ +class FunctionToolCallbackTest { + + @Test + void testConsumerToolCall() { + TestFunctionTool tool = new TestFunctionTool(); + FunctionToolCallback callback = FunctionToolCallback.builder("testTool", tool.stringConsumer()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(String.class) + .build(); + + callback.call("\"test string param\""); + + assertEquals("test string param", tool.calledValue.get()); + } + + @Test + void testBiFunctionToolCall() { + TestFunctionTool tool = new TestFunctionTool(); + FunctionToolCallback callback = FunctionToolCallback + .builder("testTool", tool.stringBiFunction()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(String.class) + .build(); + + ToolContext toolContext = new ToolContext(Map.of("foo", "bar")); + + String callResult = callback.call("\"test string param\"", toolContext); + + assertEquals("test string param", tool.calledValue.get()); + assertEquals("\"return value = test string param\"", callResult); + assertEquals(toolContext, tool.calledToolContext.get()); + } + + @Test + void testFunctionToolCall() { + TestFunctionTool tool = new TestFunctionTool(); + FunctionToolCallback callback = FunctionToolCallback.builder("testTool", tool.stringFunction()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(String.class) + .build(); + + ToolContext toolContext = new ToolContext(Map.of()); + + String callResult = callback.call("\"test string param\"", toolContext); + + assertEquals("test string param", tool.calledValue.get()); + assertEquals("\"return value = test string param\"", callResult); + } + + @Test + void testSupplierToolCall() { + TestFunctionTool tool = new TestFunctionTool(); + + FunctionToolCallback callback = FunctionToolCallback.builder("testTool", tool.stringSupplier()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(Void.class) + .build(); + + ToolContext toolContext = new ToolContext(Map.of()); + + String callResult = callback.call("\"test string param\"", toolContext); + + assertEquals("not params", tool.calledValue.get()); + assertEquals("\"return value = \"", callResult); + } + + @Test + void testThrowRuntimeException() { + TestFunctionTool tool = new TestFunctionTool(); + FunctionToolCallback callback = FunctionToolCallback + .builder("testTool", tool.throwRuntimeException()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(String.class) + .build(); + + assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception") + .hasCauseInstanceOf(RuntimeException.class) + .asInstanceOf(type(ToolExecutionException.class)) + .extracting(ToolExecutionException::getToolDefinition) + .isEqualTo(callback.getToolDefinition()); + } + + @Test + void testThrowToolExecutionException() { + TestFunctionTool tool = new TestFunctionTool(); + FunctionToolCallback callback = FunctionToolCallback + .builder("testTool", tool.throwToolExecutionException()) + .toolMetadata(ToolMetadata.builder().returnDirect(true).build()) + .description("test description") + .inputType(String.class) + .build(); + + assertThatThrownBy(() -> callback.call("\"test string param\"")).hasMessage("test exception") + .hasCauseInstanceOf(RuntimeException.class) + .isInstanceOf(ToolExecutionException.class); + } + + static class TestFunctionTool { + + AtomicReference calledValue = new AtomicReference<>(); + + AtomicReference calledToolContext = new AtomicReference<>(); + + public Consumer stringConsumer() { + return s -> { + calledValue.set(s); + }; + } + + public BiFunction stringBiFunction() { + return (s, context) -> { + calledValue.set(s); + calledToolContext.set(context); + return "return value = " + s; + }; + } + + public Function stringFunction() { + return s -> { + calledValue.set(s); + return "return value = " + s; + }; + } + + public Supplier stringSupplier() { + calledValue.set("not params"); + return () -> "return value = "; + } + + public Consumer throwRuntimeException() { + return s -> { + throw new RuntimeException("test exception"); + }; + } + + public Consumer throwToolExecutionException() { + return s -> { + throw new ToolExecutionException(null, new RuntimeException("test exception")); + }; + } + + } + +}