Skip to content

Commit a9d1805

Browse files
committed
langchain support task
1 parent 4a0a2a7 commit a9d1805

File tree

6 files changed

+141
-22
lines changed

6 files changed

+141
-22
lines changed

experimental/ai/impl/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,11 @@
5454
<groupId>io.serverlessworkflow</groupId>
5555
<artifactId>serverlessworkflow-experimental-types</artifactId>
5656
</dependency>
57+
<dependency>
58+
<groupId>io.serverlessworkflow</groupId>
59+
<artifactId>serverlessworkflow-experimental-ai-types</artifactId>
60+
<version>8.0.0-SNAPSHOT</version>
61+
<scope>compile</scope>
62+
</dependency>
5763
</dependencies>
5864
</project>

experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,6 @@
1616

1717
package io.serverlessworkflow.impl.executors.ai;
1818

19-
import java.util.ArrayList;
20-
import java.util.HashSet;
21-
import java.util.List;
22-
import java.util.Map;
23-
import java.util.ServiceLoader;
24-
import java.util.Set;
25-
import java.util.concurrent.CompletableFuture;
26-
import java.util.regex.Matcher;
27-
import java.util.regex.Pattern;
28-
2919
import dev.langchain4j.data.message.AiMessage;
3020
import dev.langchain4j.data.message.ChatMessage;
3121
import dev.langchain4j.data.message.SystemMessage;
@@ -34,6 +24,7 @@
3424
import dev.langchain4j.model.chat.response.ChatResponse;
3525
import dev.langchain4j.model.output.FinishReason;
3626
import dev.langchain4j.model.output.TokenUsage;
27+
import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
3728
import io.serverlessworkflow.api.types.TaskBase;
3829
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
3930
import io.serverlessworkflow.impl.TaskContext;
@@ -44,31 +35,51 @@
4435
import io.serverlessworkflow.impl.executors.CallableTask;
4536
import io.serverlessworkflow.impl.resources.ResourceLoader;
4637
import io.serverlessworkflow.impl.services.ChatModelService;
47-
38+
import java.util.ArrayList;
39+
import java.util.HashSet;
40+
import java.util.List;
41+
import java.util.Map;
42+
import java.util.ServiceLoader;
43+
import java.util.Set;
44+
import java.util.concurrent.CompletableFuture;
45+
import java.util.regex.Matcher;
46+
import java.util.regex.Pattern;
4847

4948
public class AIChatModelCallExecutor implements CallableTask<CallAIChatModel> {
5049

5150
private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");
5251

5352
@Override
54-
public void init(CallAIChatModel task, WorkflowApplication application, ResourceLoader loader) {
55-
56-
}
53+
public void init(CallAIChatModel task, WorkflowApplication application, ResourceLoader loader) {}
5754

5855
@Override
59-
public CompletableFuture<WorkflowModel> apply(WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) {
56+
public CompletableFuture<WorkflowModel> apply(
57+
WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) {
6058
WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory();
59+
if (taskContext.task() instanceof CallAILangChainChatModel callAILangChainChatModel) {
60+
return CompletableFuture.completedFuture(
61+
modelFactory.fromAny(doCall(callAILangChainChatModel, input.asJavaObject())));
62+
}
63+
6164
if (taskContext.task() instanceof CallAIChatModel callAIChatModel) {
62-
return CompletableFuture.completedFuture(modelFactory.fromAny(doCall(callAIChatModel, input.asJavaObject())));
65+
return CompletableFuture.completedFuture(
66+
modelFactory.fromAny(doCall(callAIChatModel, input.asJavaObject())));
6367
}
64-
throw new IllegalArgumentException("AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: " + taskContext.task().getClass().getName());
68+
throw new IllegalArgumentException(
69+
"AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
70+
+ taskContext.task().getClass().getName());
6571
}
6672

6773
@Override
6874
public boolean accept(Class<? extends TaskBase> clazz) {
6975
return CallAIChatModel.class.isAssignableFrom(clazz);
7076
}
7177

78+
private Object doCall(CallAILangChainChatModel callAIChatModel, Object javaObject) {
79+
ChatModel chatModel = callAIChatModel.getChatModel();
80+
Class<?> chatModelRequest = callAIChatModel.getChatModelRequest();
81+
}
82+
7283
private Object doCall(CallAIChatModel callAIChatModel, Object javaObject) {
7384
validate(callAIChatModel, javaObject);
7485
ChatModel chatModel = createChatModel(callAIChatModel);
@@ -114,7 +125,8 @@ private ChatModel createChatModel(CallAIChatModel callAIChatModel) {
114125
if (chatModelService != null) {
115126
return chatModelService.getChatModel(callAIChatModel.getPreferences());
116127
}
117-
throw new IllegalStateException("No LLM models found. Please ensure that you have the required dependencies in your classpath.");
128+
throw new IllegalStateException(
129+
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
118130
}
119131

120132
private ChatModelService getAvailableModel() {
@@ -124,7 +136,8 @@ private ChatModelService getAvailableModel() {
124136
return service;
125137
}
126138

127-
throw new IllegalStateException("No LLM models found. Please ensure that you have the required dependencies in your classpath.");
139+
throw new IllegalStateException(
140+
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
128141
}
129142

130143
private Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {

experimental/ai/pom.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@
1313
<packaging>pom</packaging>
1414
<modules>
1515
<module>impl</module>
16+
<module>types</module>
17+
<module>models</module>
1618
</modules>
1719
</project>

experimental/ai/types/pom.xml

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
<parent>
7+
<groupId>io.serverlessworkflow</groupId>
8+
<artifactId>serverlessworkflow-experimental-ai-parent</artifactId>
9+
<version>8.0.0-SNAPSHOT</version>
10+
</parent>
11+
<artifactId>serverlessworkflow-experimental-ai-types</artifactId>
12+
<name>ServelessWorkflow:: Experimental:: AI:: Types</name>
13+
14+
<dependencies>
15+
<dependency>
16+
<groupId>io.serverlessworkflow</groupId>
17+
<artifactId>serverlessworkflow-experimental-types</artifactId>
18+
</dependency>
19+
<dependency>
20+
<groupId>io.serverlessworkflow</groupId>
21+
<artifactId>serverlessworkflow-impl-core</artifactId>
22+
</dependency>
23+
<dependency>
24+
<groupId>dev.langchain4j</groupId>
25+
<artifactId>langchain4j</artifactId>
26+
<version>1.1.0</version>
27+
<optional>true</optional>
28+
</dependency>
29+
<dependency>
30+
<groupId>org.junit.jupiter</groupId>
31+
<artifactId>junit-jupiter-api</artifactId>
32+
<scope>test</scope>
33+
</dependency>
34+
<dependency>
35+
<groupId>org.junit.jupiter</groupId>
36+
<artifactId>junit-jupiter-engine</artifactId>
37+
<scope>test</scope>
38+
</dependency>
39+
<dependency>
40+
<groupId>org.junit.jupiter</groupId>
41+
<artifactId>junit-jupiter-params</artifactId>
42+
<scope>test</scope>
43+
</dependency>
44+
<dependency>
45+
<groupId>org.assertj</groupId>
46+
<artifactId>assertj-core</artifactId>
47+
<scope>test</scope>
48+
</dependency>
49+
<dependency>
50+
<groupId>ch.qos.logback</groupId>
51+
<artifactId>logback-classic</artifactId>
52+
<scope>test</scope>
53+
</dependency>
54+
<dependency>
55+
<groupId>io.serverlessworkflow</groupId>
56+
<artifactId>serverlessworkflow-experimental-types</artifactId>
57+
</dependency>
58+
</dependencies>
59+
60+
</project>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright 2020-Present The Serverless Workflow Specification Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.serverlessworkflow.ai.api.types;
18+
19+
import dev.langchain4j.model.chat.ChatModel;
20+
import io.serverlessworkflow.api.types.TaskBase;
21+
22+
public class CallAILangChainChatModel extends TaskBase {
23+
24+
private final ChatModel chatModel;
25+
private final Class<?> chatModelRequest;
26+
27+
public CallAILangChainChatModel(ChatModel chatModel, Class<?> chatModelRequest) {
28+
this.chatModel = chatModel;
29+
this.chatModelRequest = chatModelRequest;
30+
}
31+
32+
public ChatModel getChatModel() {
33+
return chatModel;
34+
}
35+
36+
public Class<?> getChatModelRequest() {
37+
return chatModelRequest;
38+
}
39+
}

experimental/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAIChatModel.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,19 @@
1616

1717
package io.serverlessworkflow.api.types.ai;
1818

19+
import io.serverlessworkflow.api.types.TaskBase;
1920
import java.time.Duration;
2021
import java.util.ArrayList;
2122
import java.util.Collection;
2223
import java.util.List;
2324

24-
import io.serverlessworkflow.api.types.TaskBase;
25-
2625
public class CallAIChatModel extends TaskBase {
2726

2827
private ChatModelPreferences chatModelPreferences;
2928

3029
private ChatModelRequest chatModelRequest;
3130

32-
private CallAIChatModel() {}
31+
protected CallAIChatModel() {}
3332

3433
public static Builder builder() {
3534
return new Builder();

0 commit comments

Comments
 (0)