Skip to content

Commit c6587b6

Browse files
committed
LangChainChatModel done
1 parent a9d1805 commit c6587b6

File tree

10 files changed

+321
-152
lines changed

10 files changed

+321
-152
lines changed

experimental/ai/impl/pom.xml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
<groupId>io.serverlessworkflow</groupId>
2020
<artifactId>serverlessworkflow-impl-core</artifactId>
2121
</dependency>
22+
<dependency>
23+
<groupId>io.serverlessworkflow</groupId>
24+
<artifactId>serverlessworkflow-experimental-types</artifactId>
25+
</dependency>
26+
<dependency>
27+
<groupId>io.serverlessworkflow</groupId>
28+
<artifactId>serverlessworkflow-experimental-ai-types</artifactId>
29+
</dependency>
2230
<dependency>
2331
<groupId>dev.langchain4j</groupId>
2432
<artifactId>langchain4j</artifactId>
@@ -50,15 +58,5 @@
5058
<artifactId>logback-classic</artifactId>
5159
<scope>test</scope>
5260
</dependency>
53-
<dependency>
54-
<groupId>io.serverlessworkflow</groupId>
55-
<artifactId>serverlessworkflow-experimental-types</artifactId>
56-
</dependency>
57-
<dependency>
58-
<groupId>io.serverlessworkflow</groupId>
59-
<artifactId>serverlessworkflow-experimental-ai-types</artifactId>
60-
<version>8.0.0-SNAPSHOT</version>
61-
<scope>compile</scope>
62-
</dependency>
6361
</dependencies>
64-
</project>
62+
</project>

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java

Lines changed: 6 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,6 @@
1616

1717
package io.serverlessworkflow.impl.executors.ai;
1818

19-
import dev.langchain4j.data.message.AiMessage;
20-
import dev.langchain4j.data.message.ChatMessage;
21-
import dev.langchain4j.data.message.SystemMessage;
22-
import dev.langchain4j.data.message.UserMessage;
23-
import dev.langchain4j.model.chat.ChatModel;
24-
import dev.langchain4j.model.chat.response.ChatResponse;
25-
import dev.langchain4j.model.output.FinishReason;
26-
import dev.langchain4j.model.output.TokenUsage;
2719
import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
2820
import io.serverlessworkflow.api.types.TaskBase;
2921
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
@@ -34,21 +26,10 @@
3426
import io.serverlessworkflow.impl.WorkflowModelFactory;
3527
import io.serverlessworkflow.impl.executors.CallableTask;
3628
import io.serverlessworkflow.impl.resources.ResourceLoader;
37-
import io.serverlessworkflow.impl.services.ChatModelService;
38-
import java.util.ArrayList;
39-
import java.util.HashSet;
40-
import java.util.List;
41-
import java.util.Map;
42-
import java.util.ServiceLoader;
43-
import java.util.Set;
4429
import java.util.concurrent.CompletableFuture;
45-
import java.util.regex.Matcher;
46-
import java.util.regex.Pattern;
4730

4831
public class AIChatModelCallExecutor implements CallableTask<CallAIChatModel> {
4932

50-
private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");
51-
5233
@Override
5334
public void init(CallAIChatModel task, WorkflowApplication application, ResourceLoader loader) {}
5435

@@ -58,12 +39,13 @@ public CompletableFuture<WorkflowModel> apply(
5839
WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory();
5940
if (taskContext.task() instanceof CallAILangChainChatModel callAILangChainChatModel) {
6041
return CompletableFuture.completedFuture(
61-
modelFactory.fromAny(doCall(callAILangChainChatModel, input.asJavaObject())));
62-
}
63-
64-
if (taskContext.task() instanceof CallAIChatModel callAIChatModel) {
42+
modelFactory.fromAny(
43+
new CallAILangChainChatModelExecutor()
44+
.apply(callAILangChainChatModel, input.asJavaObject())));
45+
} else if (taskContext.task() instanceof CallAIChatModel callAIChatModel) {
6546
return CompletableFuture.completedFuture(
66-
modelFactory.fromAny(doCall(callAIChatModel, input.asJavaObject())));
47+
modelFactory.fromAny(
48+
new CallAIChatModelExecutor().apply(callAIChatModel, input.asJavaObject())));
6749
}
6850
throw new IllegalArgumentException(
6951
"AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
@@ -74,112 +56,4 @@ public CompletableFuture<WorkflowModel> apply(
7456
public boolean accept(Class<? extends TaskBase> clazz) {
7557
return CallAIChatModel.class.isAssignableFrom(clazz);
7658
}
77-
78-
private Object doCall(CallAILangChainChatModel callAIChatModel, Object javaObject) {
79-
ChatModel chatModel = callAIChatModel.getChatModel();
80-
Class<?> chatModelRequest = callAIChatModel.getChatModelRequest();
81-
}
82-
83-
private Object doCall(CallAIChatModel callAIChatModel, Object javaObject) {
84-
validate(callAIChatModel, javaObject);
85-
ChatModel chatModel = createChatModel(callAIChatModel);
86-
Map<String, Object> substitutions = (Map<String, Object>) javaObject;
87-
88-
List<ChatMessage> messages = new ArrayList<>();
89-
90-
if (callAIChatModel.getChatModelRequest().getSystemMessages() != null) {
91-
for (String systemMessage : callAIChatModel.getChatModelRequest().getSystemMessages()) {
92-
String fixedUserMessage = replaceVariables(systemMessage, substitutions);
93-
messages.add(new SystemMessage(fixedUserMessage));
94-
}
95-
}
96-
97-
if (callAIChatModel.getChatModelRequest().getUserMessages() != null) {
98-
for (String userMessage : callAIChatModel.getChatModelRequest().getUserMessages()) {
99-
String fixedUserMessage = replaceVariables(userMessage, substitutions);
100-
messages.add(new UserMessage(fixedUserMessage));
101-
}
102-
}
103-
104-
return prepareResponse(chatModel.chat(messages), javaObject);
105-
}
106-
107-
private String replaceVariables(String template, Map<String, Object> substitutions) {
108-
Set<String> variables = extractVariables(template);
109-
for (Map.Entry<String, Object> entry : substitutions.entrySet()) {
110-
String variable = entry.getKey();
111-
Object value = entry.getValue();
112-
if (value != null && variables.contains(variable)) {
113-
template = template.replace("{{" + variable + "}}", value.toString());
114-
}
115-
}
116-
return template;
117-
}
118-
119-
private void validate(CallAIChatModel callAIChatModel, Object javaObject) {
120-
// TODO
121-
}
122-
123-
private ChatModel createChatModel(CallAIChatModel callAIChatModel) {
124-
ChatModelService chatModelService = getAvailableModel();
125-
if (chatModelService != null) {
126-
return chatModelService.getChatModel(callAIChatModel.getPreferences());
127-
}
128-
throw new IllegalStateException(
129-
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
130-
}
131-
132-
private ChatModelService getAvailableModel() {
133-
ServiceLoader<ChatModelService> loader = ServiceLoader.load(ChatModelService.class);
134-
135-
for (ChatModelService service : loader) {
136-
return service;
137-
}
138-
139-
throw new IllegalStateException(
140-
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
141-
}
142-
143-
private Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
144-
145-
String id = response.id();
146-
String modelName = response.modelName();
147-
TokenUsage tokenUsage = response.tokenUsage();
148-
FinishReason finishReason = response.finishReason();
149-
AiMessage aiMessage = response.aiMessage();
150-
151-
Map<String, Object> responseMap = (Map<String, Object>) javaObject;
152-
if (response.id() != null) {
153-
responseMap.put("id", id);
154-
}
155-
156-
if (modelName != null) {
157-
responseMap.put("modelName", modelName);
158-
}
159-
160-
if (tokenUsage != null) {
161-
responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
162-
responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
163-
responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
164-
}
165-
166-
if (finishReason != null) {
167-
responseMap.put("finishReason", finishReason.name());
168-
}
169-
170-
if (aiMessage != null) {
171-
responseMap.put("text", aiMessage.text());
172-
}
173-
174-
return responseMap;
175-
}
176-
177-
private static Set<String> extractVariables(String template) {
178-
Set<String> variables = new HashSet<>();
179-
Matcher matcher = VARIABLE_PATTERN.matcher(template);
180-
while (matcher.find()) {
181-
variables.add(matcher.group(1));
182-
}
183-
return variables;
184-
}
18559
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright 2020-Present The Serverless Workflow Specification Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.serverlessworkflow.impl.executors.ai;
18+
19+
import dev.langchain4j.data.message.AiMessage;
20+
import dev.langchain4j.model.chat.response.ChatResponse;
21+
import dev.langchain4j.model.output.FinishReason;
22+
import dev.langchain4j.model.output.TokenUsage;
23+
import java.util.Map;
24+
25+
public abstract class AbstractCallAIChatModelExecutor<T> {
26+
27+
public abstract Object apply(T callAIChatModel, Object javaObject);
28+
29+
protected Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
30+
String id = response.id();
31+
String modelName = response.modelName();
32+
TokenUsage tokenUsage = response.tokenUsage();
33+
FinishReason finishReason = response.finishReason();
34+
AiMessage aiMessage = response.aiMessage();
35+
36+
Map<String, Object> responseMap = (Map<String, Object>) javaObject;
37+
if (response.id() != null) {
38+
responseMap.put("id", id);
39+
}
40+
41+
if (modelName != null) {
42+
responseMap.put("modelName", modelName);
43+
}
44+
45+
if (tokenUsage != null) {
46+
responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
47+
responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
48+
responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
49+
}
50+
51+
if (finishReason != null) {
52+
responseMap.put("finishReason", finishReason.name());
53+
}
54+
55+
if (aiMessage != null) {
56+
responseMap.put("text", aiMessage.text());
57+
}
58+
59+
return responseMap;
60+
}
61+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2020-Present The Serverless Workflow Specification Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.serverlessworkflow.impl.executors.ai;
18+
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.data.message.SystemMessage;
21+
import dev.langchain4j.data.message.UserMessage;
22+
import dev.langchain4j.model.chat.ChatModel;
23+
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
24+
import io.serverlessworkflow.impl.services.ChatModelService;
25+
import java.util.ArrayList;
26+
import java.util.HashSet;
27+
import java.util.List;
28+
import java.util.Map;
29+
import java.util.ServiceLoader;
30+
import java.util.Set;
31+
import java.util.regex.Matcher;
32+
import java.util.regex.Pattern;
33+
34+
public class CallAIChatModelExecutor extends AbstractCallAIChatModelExecutor<CallAIChatModel> {
35+
36+
private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");
37+
38+
@Override
39+
public Object apply(CallAIChatModel callAIChatModel, Object javaObject) {
40+
validate(callAIChatModel, javaObject);
41+
42+
ChatModel chatModel = createChatModel(callAIChatModel);
43+
Map<String, Object> substitutions = (Map<String, Object>) javaObject;
44+
45+
List<ChatMessage> messages = new ArrayList<>();
46+
47+
if (callAIChatModel.getChatModelRequest().getSystemMessages() != null) {
48+
for (String systemMessage : callAIChatModel.getChatModelRequest().getSystemMessages()) {
49+
String fixedUserMessage = replaceVariables(systemMessage, substitutions);
50+
messages.add(new SystemMessage(fixedUserMessage));
51+
}
52+
}
53+
54+
if (callAIChatModel.getChatModelRequest().getUserMessages() != null) {
55+
for (String userMessage : callAIChatModel.getChatModelRequest().getUserMessages()) {
56+
String fixedUserMessage = replaceVariables(userMessage, substitutions);
57+
messages.add(new UserMessage(fixedUserMessage));
58+
}
59+
}
60+
61+
return prepareResponse(chatModel.chat(messages), javaObject);
62+
}
63+
64+
private ChatModel createChatModel(CallAIChatModel callAIChatModel) {
65+
ChatModelService chatModelService = getAvailableModel();
66+
if (chatModelService != null) {
67+
return chatModelService.getChatModel(callAIChatModel.getPreferences());
68+
}
69+
throw new IllegalStateException(
70+
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
71+
}
72+
73+
private String replaceVariables(String template, Map<String, Object> substitutions) {
74+
Set<String> variables = extractVariables(template);
75+
for (Map.Entry<String, Object> entry : substitutions.entrySet()) {
76+
String variable = entry.getKey();
77+
Object value = entry.getValue();
78+
if (value != null && variables.contains(variable)) {
79+
template = template.replace("{{" + variable + "}}", value.toString());
80+
}
81+
}
82+
return template;
83+
}
84+
85+
private ChatModelService getAvailableModel() {
86+
ServiceLoader<ChatModelService> loader = ServiceLoader.load(ChatModelService.class);
87+
88+
for (ChatModelService service : loader) {
89+
return service;
90+
}
91+
92+
throw new IllegalStateException(
93+
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
94+
}
95+
96+
private static Set<String> extractVariables(String template) {
97+
Set<String> variables = new HashSet<>();
98+
Matcher matcher = VARIABLE_PATTERN.matcher(template);
99+
while (matcher.find()) {
100+
variables.add(matcher.group(1));
101+
}
102+
return variables;
103+
}
104+
105+
private void validate(CallAIChatModel callAIChatModel, Object javaObject) {
106+
// TODO
107+
}
108+
}

0 commit comments

Comments
 (0)