Skip to content

Commit 6270d62

Browse files
committed
ChatClient register functions with explicit input type
The Lambda functions do not retain the type information, so we need to provide the input type explicitly. Resolves #1052 Co-authored-by: liuzhifei <[email protected]>
1 parent 03d1d50 commit 6270d62

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
1919

20+
import java.lang.reflect.Method;
2021
import java.util.List;
22+
import java.util.function.Function;
2123
import java.util.stream.Collectors;
2224

2325
import org.junit.jupiter.api.Test;
2426
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2527
import org.slf4j.Logger;
2628
import org.slf4j.LoggerFactory;
2729
import org.springframework.ai.chat.client.ChatClient;
30+
import org.springframework.ai.chat.client.DefaultChatClient;
2831
import org.springframework.ai.openai.OpenAiTestConfiguration;
2932
import org.springframework.ai.openai.api.tool.MockWeatherService;
3033
import org.springframework.ai.openai.testutils.AbstractIT;
@@ -123,4 +126,47 @@ void streamFunctionCallTest() {
123126

124127
}
125128

129+
@Test
130+
void functionCallWithExplicitInputType() throws NoSuchMethodException {
131+
132+
var chatClient = ChatClient.create(chatModel);
133+
134+
Method currentTemp = MyFunction.class.getMethod("getCurrentTemp", MyFunction.Req.class);
135+
136+
// NOTE: Lambda functions do not retain the type information, so we need to
137+
// provide the input type explicitly.
138+
MyFunction myFunction = new MyFunction();
139+
Function<MyFunction.Req, Object> function = createFunction(myFunction, currentTemp);
140+
141+
ChatClient.ChatClientRequestSpec chatClientRequestSpec = chatClient.prompt()
142+
.user("What's the weather like in Shanghai?")
143+
.function("currentTemp", "get current temp", MyFunction.Req.class, function);
144+
145+
String content = chatClientRequestSpec.call().content();
146+
147+
assertThat(content).contains("23");
148+
}
149+
150+
public static <T, R> Function<T, R> createFunction(Object obj, Method method) {
151+
return (T t) -> {
152+
try {
153+
return (R) method.invoke(obj, t);
154+
}
155+
catch (Exception e) {
156+
throw new RuntimeException(e);
157+
}
158+
};
159+
}
160+
161+
public static class MyFunction {
162+
163+
public record Req(String city) {
164+
}
165+
166+
public String getCurrentTemp(Req req) {
167+
return "23";
168+
}
169+
170+
}
171+
126172
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ interface ChatClientRequestSpec {
189189
<I, O> ChatClientRequestSpec function(String name, String description,
190190
java.util.function.Function<I, O> function);
191191

192+
<I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
193+
java.util.function.Function<I, O> function);
194+
192195
ChatClientRequestSpec functions(String... functionBeanNames);
193196

194197
ChatClientRequestSpec system(String text);

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,11 @@ public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
610610

611611
public <I, O> ChatClientRequestSpec function(String name, String description,
612612
java.util.function.Function<I, O> function) {
613+
return this.function(name, description, null, function);
614+
}
615+
616+
public <I, O> ChatClientRequestSpec function(String name, String description, Class<I> inputType,
617+
java.util.function.Function<I, O> function) {
613618

614619
Assert.hasText(name, "the name must be non-null and non-empty");
615620
Assert.hasText(description, "the description must be non-null and non-empty");
@@ -618,6 +623,7 @@ public <I, O> ChatClientRequestSpec function(String name, String description,
618623
var fcw = FunctionCallbackWrapper.builder(function)
619624
.withDescription(description)
620625
.withName(name)
626+
.withInputType(inputType)
621627
.withResponseConverter(Object::toString)
622628
.build();
623629
this.functionCallbacks.add(fcw);

0 commit comments

Comments
 (0)