Initial CallChatModel implementation #668
@@ -0,0 +1,62 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>io.serverlessworkflow</groupId>
    <artifactId>serverlessworkflow-experimental-ai-parent</artifactId>
    <version>8.0.0-SNAPSHOT</version>
  </parent>
  <artifactId>serverlessworkflow-experimental-ai-impl</artifactId>
  <name>ServerlessWorkflow:: Experimental:: AI:: Impl</name>
  <dependencies>
    <dependency>
      <groupId>io.serverlessworkflow</groupId>
      <artifactId>serverlessworkflow-experimental-types</artifactId>
    </dependency>
    <dependency>
      <groupId>io.serverlessworkflow</groupId>
      <artifactId>serverlessworkflow-impl-core</artifactId>
    </dependency>
    <dependency>
      <groupId>io.serverlessworkflow</groupId>
      <artifactId>serverlessworkflow-experimental-types</artifactId>
    </dependency>
    <dependency>
      <groupId>io.serverlessworkflow</groupId>
      <artifactId>serverlessworkflow-experimental-ai-types</artifactId>
    </dependency>
    <dependency>
      <groupId>dev.langchain4j</groupId>
      <artifactId>langchain4j</artifactId>
      <version>1.1.0</version>
      <optional>true</optional>

Review comment: Why optional?

    </dependency>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-api</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-engine</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-params</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.assertj</groupId>
      <artifactId>assertj-core</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>ch.qos.logback</groupId>
      <artifactId>logback-classic</artifactId>
      <scope>test</scope>
    </dependency>
  </dependencies>
</project>
@@ -0,0 +1,61 @@
/*
 * Copyright 2020-Present The Serverless Workflow Specification Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.serverlessworkflow.impl.executors.ai;

import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
import io.serverlessworkflow.api.types.TaskBase;
import io.serverlessworkflow.api.types.ai.AbstractCallAIChatModelTask;
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
import io.serverlessworkflow.impl.TaskContext;
import io.serverlessworkflow.impl.WorkflowApplication;
import io.serverlessworkflow.impl.WorkflowContext;
import io.serverlessworkflow.impl.WorkflowModel;
import io.serverlessworkflow.impl.WorkflowModelFactory;
import io.serverlessworkflow.impl.executors.CallableTask;
import io.serverlessworkflow.impl.resources.ResourceLoader;
import java.util.concurrent.CompletableFuture;

public class AIChatModelCallExecutor implements CallableTask<AbstractCallAIChatModelTask> {

  @Override
  public void init(
      AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) {}

  @Override
  public CompletableFuture<WorkflowModel> apply(
      WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) {
    WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory();
    if (taskContext.task() instanceof CallAILangChainChatModel callAILangChainChatModel) {
      return CompletableFuture.completedFuture(
          modelFactory.fromAny(
              new CallAILangChainChatModelExecutor()
                  .apply(callAILangChainChatModel, input.asJavaObject())));
    } else if (taskContext.task() instanceof CallAIChatModel callAIChatModel) {
      return CompletableFuture.completedFuture(
          modelFactory.fromAny(
              new CallAIChatModelExecutor().apply(callAIChatModel, input.asJavaObject())));
    }
    throw new IllegalArgumentException(
        "AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
            + taskContext.task().getClass().getName());
  }

  @Override
  public boolean accept(Class<? extends TaskBase> clazz) {
    return AbstractCallAIChatModelTask.class.isAssignableFrom(clazz);
  }
}
@@ -0,0 +1,21 @@
/*
 * Copyright 2020-Present The Serverless Workflow Specification Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.serverlessworkflow.impl.executors.ai;

import io.serverlessworkflow.impl.executors.DefaultTaskExecutorFactory;

public class AIChatModelTaskExecutorFactory extends DefaultTaskExecutorFactory {}
@@ -0,0 +1,22 @@
/*
 * Copyright 2020-Present The Serverless Workflow Specification Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.serverlessworkflow.impl.executors.ai;

public abstract class AbstractCallAIChatModelExecutor<T> {

Review comment: This should be an interface, not an abstract class, so it should be renamed to AIChatModelExecutor.

  public abstract Object apply(T callAIChatModel, Object javaObject);
}
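Following up on that comment, a minimal sketch of what the suggested interface might look like; this is illustrative only and not part of the PR, with the name and generic shape taken from the review comment and the abstract class above:

package io.serverlessworkflow.impl.executors.ai;

// Hypothetical shape suggested by the review comment: the abstract class above carries no
// state, so its contract can be expressed as a single-method generic interface.
public interface AIChatModelExecutor<T> {

  // Same contract as the abstract apply(...) method above.
  Object apply(T callAIChatModel, Object javaObject);
}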
@@ -0,0 +1,145 @@
/*
 * Copyright 2020-Present The Serverless Workflow Specification Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.serverlessworkflow.impl.executors.ai;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
import io.serverlessworkflow.impl.services.ChatModelService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class CallAIChatModelExecutor extends AbstractCallAIChatModelExecutor<CallAIChatModel> {

  private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");

  @Override
  public Object apply(CallAIChatModel callAIChatModel, Object javaObject) {
    validate(callAIChatModel, javaObject);

    ChatModel chatModel = createChatModel(callAIChatModel);
    Map<String, Object> substitutions = (Map<String, Object>) javaObject;

    List<ChatMessage> messages = new ArrayList<>();

    if (callAIChatModel.getChatModelRequest().getSystemMessages() != null) {
      for (String systemMessage : callAIChatModel.getChatModelRequest().getSystemMessages()) {
        String fixedUserMessage = replaceVariables(systemMessage, substitutions);
        messages.add(new SystemMessage(fixedUserMessage));
      }
    }

    if (callAIChatModel.getChatModelRequest().getUserMessages() != null) {
      for (String userMessage : callAIChatModel.getChatModelRequest().getUserMessages()) {
        String fixedUserMessage = replaceVariables(userMessage, substitutions);
        messages.add(new UserMessage(fixedUserMessage));
      }
    }

    return prepareResponse(chatModel.chat(messages), javaObject);
  }

  private ChatModel createChatModel(CallAIChatModel callAIChatModel) {
    ChatModelService chatModelService = getAvailableModel();
    if (chatModelService != null) {
      return chatModelService.getChatModel(callAIChatModel.getPreferences());
    }
    throw new IllegalStateException(
        "No LLM models found. Please ensure that you have the required dependencies in your classpath.");
  }

  private String replaceVariables(String template, Map<String, Object> substitutions) {
    Set<String> variables = extractVariables(template);
    for (Map.Entry<String, Object> entry : substitutions.entrySet()) {
      String variable = entry.getKey();
      Object value = entry.getValue();
      if (value != null && variables.contains(variable)) {
        template = template.replace("{{" + variable + "}}", value.toString());
      }
    }
    return template;
  }

  private ChatModelService getAvailableModel() {
    ServiceLoader<ChatModelService> loader = ServiceLoader.load(ChatModelService.class);

    for (ChatModelService service : loader) {
      return service;
    }

    throw new IllegalStateException(
        "No LLM models found. Please ensure that you have the required dependencies in your classpath.");
  }

  private static Set<String> extractVariables(String template) {
    Set<String> variables = new HashSet<>();
    Matcher matcher = VARIABLE_PATTERN.matcher(template);
    while (matcher.find()) {
      variables.add(matcher.group(1));
    }
    return variables;
  }

  private void validate(CallAIChatModel callAIChatModel, Object javaObject) {
    // TODO
  }

  protected Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
    String id = response.id();
    String modelName = response.modelName();
    TokenUsage tokenUsage = response.tokenUsage();
    FinishReason finishReason = response.finishReason();
    AiMessage aiMessage = response.aiMessage();

    Map<String, Object> responseMap = (Map<String, Object>) javaObject;
    if (response.id() != null) {
      responseMap.put("id", id);
    }

    if (modelName != null) {
      responseMap.put("modelName", modelName);
    }

    if (tokenUsage != null) {
      responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
      responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
      responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
    }

    if (finishReason != null) {
      responseMap.put("finishReason", finishReason.name());
    }

    if (aiMessage != null) {
      responseMap.put("result", aiMessage.text());
    }

    return responseMap;
  }
}
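For readers skimming the executor above: it fills {{variable}} placeholders in system and user messages from the workflow input map before calling the model. A small standalone sketch of that templating follows; it reproduces the private replaceVariables/extractVariables logic for illustration only and is not part of the PR.

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class TemplateSubstitutionSketch {

  // Same pattern as CallAIChatModelExecutor: captures the name inside {{ ... }},
  // tolerating whitespace around it.
  private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");

  static Set<String> extractVariables(String template) {
    Set<String> variables = new HashSet<>();
    Matcher matcher = VARIABLE_PATTERN.matcher(template);
    while (matcher.find()) {
      variables.add(matcher.group(1));
    }
    return variables;
  }

  static String replaceVariables(String template, Map<String, Object> substitutions) {
    Set<String> variables = extractVariables(template);
    for (Map.Entry<String, Object> entry : substitutions.entrySet()) {
      if (entry.getValue() != null && variables.contains(entry.getKey())) {
        // Like the PR code, this replaces the exact "{{key}}" form (no inner whitespace).
        template = template.replace("{{" + entry.getKey() + "}}", entry.getValue().toString());
      }
    }
    return template;
  }

  public static void main(String[] args) {
    // Prints: Summarize serverless workflows in one sentence.
    System.out.println(
        replaceVariables(
            "Summarize {{topic}} in one sentence.", Map.of("topic", "serverless workflows")));
  }
}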
Review comment: This dependency is duplicated (the pom above declares serverlessworkflow-experimental-types twice).
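Finally, since CallAIChatModelExecutor discovers its ChatModelService through the JDK ServiceLoader, a backing model has to be registered as a service provider. Below is a hedged sketch of what such a provider could look like; the getChatModel signature is an assumption inferred from the call site above (the actual ChatModelService interface from this PR is not shown in the diff), and the OpenAI builder comes from the separate langchain4j-open-ai artifact, which is not among the dependencies in the pom.

package com.example.ai; // hypothetical provider package

import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel; // from langchain4j-open-ai
import io.serverlessworkflow.impl.services.ChatModelService;

// Hypothetical provider; the real ChatModelService method signature may differ.
public class OpenAiChatModelService implements ChatModelService {

  @Override
  public ChatModel getChatModel(Object preferences) {
    // A real provider would map the task's preferences to builder options;
    // they are ignored here for brevity.
    return OpenAiChatModel.builder()
        .apiKey(System.getenv("OPENAI_API_KEY"))
        .modelName("gpt-4o-mini")
        .build();
  }
}

Such a provider would then be declared in META-INF/services/io.serverlessworkflow.impl.services.ChatModelService so that ServiceLoader can find it at runtime.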