Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions experimental/ai/impl/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.serverlessworkflow</groupId>
<artifactId>serverlessworkflow-experimental-ai-parent</artifactId>
<version>8.0.0-SNAPSHOT</version>
</parent>
<artifactId>serverlessworkflow-experimental-ai-impl</artifactId>
<name>ServelessWorkflow:: Experimental:: AI:: Impl</name>
<dependencies>
<dependency>
<groupId>io.serverlessworkflow</groupId>
<artifactId>serverlessworkflow-experimental-types</artifactId>
</dependency>
<dependency>
<groupId>io.serverlessworkflow</groupId>
<artifactId>serverlessworkflow-impl-core</artifactId>
</dependency>
<dependency>
<groupId>io.serverlessworkflow</groupId>
<artifactId>serverlessworkflow-experimental-types</artifactId>
</dependency>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dependency is duplicated

<dependency>
<groupId>io.serverlessworkflow</groupId>
<artifactId>serverlessworkflow-experimental-ai-types</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>1.1.0</version>
<optional>true</optional>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why optional?

</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.serverlessworkflow.impl.executors.ai;

import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel;
import io.serverlessworkflow.api.types.TaskBase;
import io.serverlessworkflow.api.types.ai.AbstractCallAIChatModelTask;
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
import io.serverlessworkflow.impl.TaskContext;
import io.serverlessworkflow.impl.WorkflowApplication;
import io.serverlessworkflow.impl.WorkflowContext;
import io.serverlessworkflow.impl.WorkflowModel;
import io.serverlessworkflow.impl.WorkflowModelFactory;
import io.serverlessworkflow.impl.executors.CallableTask;
import io.serverlessworkflow.impl.resources.ResourceLoader;
import java.util.concurrent.CompletableFuture;

public class AIChatModelCallExecutor implements CallableTask<AbstractCallAIChatModelTask> {

@Override
public void init(
AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) {}

@Override
public CompletableFuture<WorkflowModel> apply(
WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) {
WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory();
if (taskContext.task() instanceof CallAILangChainChatModel callAILangChainChatModel) {
return CompletableFuture.completedFuture(
modelFactory.fromAny(
new CallAILangChainChatModelExecutor()
.apply(callAILangChainChatModel, input.asJavaObject())));
} else if (taskContext.task() instanceof CallAIChatModel callAIChatModel) {
return CompletableFuture.completedFuture(
modelFactory.fromAny(
new CallAIChatModelExecutor().apply(callAIChatModel, input.asJavaObject())));
}
throw new IllegalArgumentException(
"AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
+ taskContext.task().getClass().getName());
}

@Override
public boolean accept(Class<? extends TaskBase> clazz) {
return AbstractCallAIChatModelTask.class.isAssignableFrom(clazz);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.serverlessworkflow.impl.executors.ai;

import io.serverlessworkflow.impl.executors.DefaultTaskExecutorFactory;

public class AIChatModelTaskExecutorFactory extends DefaultTaskExecutorFactory {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.serverlessworkflow.impl.executors.ai;

public abstract class AbstractCallAIChatModelExecutor<T> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be an interface, not an abstract class. so it should be renamed as AIChatModelExecutor
And it can extends BiFunction<T,Object,Object>, not define the apply


public abstract Object apply(T callAIChatModel, Object javaObject);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright 2020-Present The Serverless Workflow Specification Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.serverlessworkflow.impl.executors.ai;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import io.serverlessworkflow.api.types.ai.CallAIChatModel;
import io.serverlessworkflow.impl.services.ChatModelService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class CallAIChatModelExecutor extends AbstractCallAIChatModelExecutor<CallAIChatModel> {

private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}");

@Override
public Object apply(CallAIChatModel callAIChatModel, Object javaObject) {
validate(callAIChatModel, javaObject);

ChatModel chatModel = createChatModel(callAIChatModel);
Map<String, Object> substitutions = (Map<String, Object>) javaObject;

List<ChatMessage> messages = new ArrayList<>();

if (callAIChatModel.getChatModelRequest().getSystemMessages() != null) {
for (String systemMessage : callAIChatModel.getChatModelRequest().getSystemMessages()) {
String fixedUserMessage = replaceVariables(systemMessage, substitutions);
messages.add(new SystemMessage(fixedUserMessage));
}
}

if (callAIChatModel.getChatModelRequest().getUserMessages() != null) {
for (String userMessage : callAIChatModel.getChatModelRequest().getUserMessages()) {
String fixedUserMessage = replaceVariables(userMessage, substitutions);
messages.add(new UserMessage(fixedUserMessage));
}
}

return prepareResponse(chatModel.chat(messages), javaObject);
}

private ChatModel createChatModel(CallAIChatModel callAIChatModel) {
ChatModelService chatModelService = getAvailableModel();
if (chatModelService != null) {
return chatModelService.getChatModel(callAIChatModel.getPreferences());
}
throw new IllegalStateException(
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
}

private String replaceVariables(String template, Map<String, Object> substitutions) {
Set<String> variables = extractVariables(template);
for (Map.Entry<String, Object> entry : substitutions.entrySet()) {
String variable = entry.getKey();
Object value = entry.getValue();
if (value != null && variables.contains(variable)) {
template = template.replace("{{" + variable + "}}", value.toString());
}
}
return template;
}

private ChatModelService getAvailableModel() {
ServiceLoader<ChatModelService> loader = ServiceLoader.load(ChatModelService.class);

for (ChatModelService service : loader) {
return service;
}

throw new IllegalStateException(
"No LLM models found. Please ensure that you have the required dependencies in your classpath.");
}

private static Set<String> extractVariables(String template) {
Set<String> variables = new HashSet<>();
Matcher matcher = VARIABLE_PATTERN.matcher(template);
while (matcher.find()) {
variables.add(matcher.group(1));
}
return variables;
}

private void validate(CallAIChatModel callAIChatModel, Object javaObject) {
// TODO
}

protected Map<String, Object> prepareResponse(ChatResponse response, Object javaObject) {
String id = response.id();
String modelName = response.modelName();
TokenUsage tokenUsage = response.tokenUsage();
FinishReason finishReason = response.finishReason();
AiMessage aiMessage = response.aiMessage();

Map<String, Object> responseMap = (Map<String, Object>) javaObject;
if (response.id() != null) {
responseMap.put("id", id);
}

if (modelName != null) {
responseMap.put("modelName", modelName);
}

if (tokenUsage != null) {
responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount());
responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount());
responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount());
}

if (finishReason != null) {
responseMap.put("finishReason", finishReason.name());
}

if (aiMessage != null) {
responseMap.put("result", aiMessage.text());
}

return responseMap;
}
}
Loading