Skip to content

Commit a8a4b60

Browse files
committed
CallAILangChainChatModel works
1 parent c6587b6 commit a8a4b60

File tree

4 files changed

+55
-58
lines changed

4 files changed

+55
-58
lines changed

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
2020
import io.serverlessworkflow.api.types.TaskBase;
21+
import io.serverlessworkflow.api.types.ai.AbstractCallAIChatModelTask;
2122
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
2223
import io.serverlessworkflow.impl.TaskContext;
2324
import io.serverlessworkflow.impl.WorkflowApplication;
@@ -28,10 +29,11 @@
2829
import io.serverlessworkflow.impl.resources.ResourceLoader;
2930
import java.util.concurrent.CompletableFuture;
3031

31-
public class AIChatModelCallExecutor implements CallableTask<CallAIChatModel> {
32+
public class AIChatModelCallExecutor implements CallableTask<AbstractCallAIChatModelTask> {
3233

3334
@Override
34-
public void init(CallAIChatModel task, WorkflowApplication application, ResourceLoader loader) {}
35+
public void init(
36+
AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) {}
3537

3638
@Override
3739
public CompletableFuture<WorkflowModel> apply(
@@ -54,6 +56,6 @@ public CompletableFuture<WorkflowModel> apply(
5456

5557
@Override
5658
public boolean accept(Class<? extends TaskBase> clazz) {
57-
return CallAIChatModel.class.isAssignableFrom(clazz);
59+
return AbstractCallAIChatModelTask.class.isAssignableFrom(clazz);
5860
}
5961
}

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AbstractCallAIChatModelExecutor.java

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,7 @@
1616

1717
package io.serverlessworkflow.impl.executors.ai;
1818

19-
import dev.langchain4j.data.message.AiMessage;
20-
import dev.langchain4j.model.chat.response.ChatResponse;
21-
import dev.langchain4j.model.output.FinishReason;
22-
import dev.langchain4j.model.output.TokenUsage;
23-
import java.util.Map;
24-
2519
public abstract class AbstractCallAIChatModelExecutor<T> {
2620

2721
public abstract Object apply(T callAIChatModel, Object javaObject);
28-
29-
protected Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
30-
String id = response.id();
31-
String modelName = response.modelName();
32-
TokenUsage tokenUsage = response.tokenUsage();
33-
FinishReason finishReason = response.finishReason();
34-
AiMessage aiMessage = response.aiMessage();
35-
36-
Map<String, Object> responseMap = (Map<String, Object>) javaObject;
37-
if (response.id() != null) {
38-
responseMap.put("id", id);
39-
}
40-
41-
if (modelName != null) {
42-
responseMap.put("modelName", modelName);
43-
}
44-
45-
if (tokenUsage != null) {
46-
responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
47-
responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
48-
responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
49-
}
50-
51-
if (finishReason != null) {
52-
responseMap.put("finishReason", finishReason.name());
53-
}
54-
55-
if (aiMessage != null) {
56-
responseMap.put("text", aiMessage.text());
57-
}
58-
59-
return responseMap;
60-
}
6122
}

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAIChatModelExecutor.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@
1616

1717
package io.serverlessworkflow.impl.executors.ai;
1818

19+
import dev.langchain4j.data.message.AiMessage;
1920
import dev.langchain4j.data.message.ChatMessage;
2021
import dev.langchain4j.data.message.SystemMessage;
2122
import dev.langchain4j.data.message.UserMessage;
2223
import dev.langchain4j.model.chat.ChatModel;
24+
import dev.langchain4j.model.chat.response.ChatResponse;
25+
import dev.langchain4j.model.output.FinishReason;
26+
import dev.langchain4j.model.output.TokenUsage;
2327
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
2428
import io.serverlessworkflow.impl.services.ChatModelService;
2529
import java.util.ArrayList;
@@ -105,4 +109,37 @@ private static Set<String> extractVariables(String template) {
105109
private void validate(CallAIChatModel callAIChatModel, Object javaObject) {
106110
// TODO
107111
}
112+
113+
protected Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
114+
String id = response.id();
115+
String modelName = response.modelName();
116+
TokenUsage tokenUsage = response.tokenUsage();
117+
FinishReason finishReason = response.finishReason();
118+
AiMessage aiMessage = response.aiMessage();
119+
120+
Map<String, Object> responseMap = (Map<String, Object>) javaObject;
121+
if (response.id() != null) {
122+
responseMap.put("id", id);
123+
}
124+
125+
if (modelName != null) {
126+
responseMap.put("modelName", modelName);
127+
}
128+
129+
if (tokenUsage != null) {
130+
responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
131+
responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
132+
responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
133+
}
134+
135+
if (finishReason != null) {
136+
responseMap.put("finishReason", finishReason.name());
137+
}
138+
139+
if (aiMessage != null) {
140+
responseMap.put("text", aiMessage.text());
141+
}
142+
143+
return responseMap;
144+
}
108145
}

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAILangChainChatModelExecutor.java

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package io.serverlessworkflow.impl.executors.ai;
1818

1919
import dev.langchain4j.model.chat.ChatModel;
20-
import dev.langchain4j.model.chat.response.ChatResponse;
2120
import dev.langchain4j.service.AiServices;
2221
import dev.langchain4j.service.V;
2322
import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
@@ -49,8 +48,10 @@ public Object apply(CallAILangChainChatModel callAIChatModel, Object javaObject)
4948
}
5049

5150
Object response = method.invoke(aiServices, args);
52-
if (response instanceof ChatResponse chatResponse) {
53-
return prepareResponse(chatResponse, substitutions);
51+
52+
if (response instanceof String chatResponse) {
53+
substitutions.put("text", chatResponse);
54+
return substitutions;
5455
} else {
5556
throw new IllegalArgumentException(
5657
"Method " + method.getName() + " did not return a ChatResponse");
@@ -65,7 +66,7 @@ private void validate(
6566

6667
private Method getMethod(Class<?> chatModelRequest, String methodName) {
6768
for (Method method : chatModelRequest.getMethods()) {
68-
if (method.getName().equals("methodName")) {
69+
if (method.getName().equals(methodName)) {
6970
return method;
7071
}
7172
}
@@ -76,20 +77,16 @@ private Method getMethod(Class<?> chatModelRequest, String methodName) {
7677
private List<String> resolvedParameters(Method method, Map<String, Object> substitutions) {
7778
List<String> resolvedParameters = new ArrayList<>();
7879
for (Parameter parameter : method.getParameters()) {
79-
String paramName = resolveParameter(parameter);
80-
if (substitutions.containsKey(paramName)) {
81-
resolvedParameters.add(paramName);
82-
} else {
83-
throw new IllegalArgumentException("Missing substitution for parameter: " + paramName);
80+
if (parameter.getAnnotation(V.class) != null) {
81+
V v = parameter.getAnnotation(V.class);
82+
String paramName = v.value();
83+
if (substitutions.containsKey(paramName)) {
84+
resolvedParameters.add(paramName);
85+
} else {
86+
throw new IllegalArgumentException("Missing substitution for parameter: " + paramName);
87+
}
8488
}
8589
}
8690
return resolvedParameters;
8791
}
88-
89-
private String resolveParameter(Parameter parameter) {
90-
if (parameter.getAnnotation(V.class) != null) {
91-
return parameter.getName();
92-
}
93-
return parameter.getName();
94-
}
9592
}

0 commit comments

Comments
 (0)