diff --git a/experimental/ai/impl/pom.xml b/experimental/ai/impl/pom.xml new file mode 100644 index 00000000..4757feca --- /dev/null +++ b/experimental/ai/impl/pom.xml @@ -0,0 +1,53 @@ + + + 4.0.0 + + io.serverlessworkflow + serverlessworkflow-experimental-ai-parent + 8.0.0-SNAPSHOT + + serverlessworkflow-experimental-ai-impl + ServelessWorkflow:: Experimental:: AI:: Impl + + + io.serverlessworkflow + serverlessworkflow-impl-core + + + io.serverlessworkflow + serverlessworkflow-experimental-ai-types + + + dev.langchain4j + langchain4j + 1.1.0 + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + ch.qos.logback + logback-classic + test + + + diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java new file mode 100644 index 00000000..adfbba91 --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelCallExecutor.java @@ -0,0 +1,58 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.executors.ai; + +import io.serverlessworkflow.api.types.TaskBase; +import io.serverlessworkflow.api.types.ai.AbstractCallAIChatModelTask; +import io.serverlessworkflow.api.types.ai.CallAIChatModel; +import io.serverlessworkflow.api.types.ai.CallAILangChainChatModel; +import io.serverlessworkflow.impl.TaskContext; +import io.serverlessworkflow.impl.WorkflowApplication; +import io.serverlessworkflow.impl.WorkflowContext; +import io.serverlessworkflow.impl.WorkflowModel; +import io.serverlessworkflow.impl.WorkflowModelFactory; +import io.serverlessworkflow.impl.executors.CallableTask; +import io.serverlessworkflow.impl.resources.ResourceLoader; +import java.util.concurrent.CompletableFuture; + +public class AIChatModelCallExecutor implements CallableTask { + + private AIChatModelExecutor executor; + + @Override + public void init( + AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) { + if (task instanceof CallAILangChainChatModel model) { + executor = new CallAILangChainChatModelExecutor(model); + } else if (task instanceof CallAIChatModel model) { + executor = new CallAIChatModelExecutor(model); + } + } + + @Override + public CompletableFuture apply( + WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) { + WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory(); + return CompletableFuture.completedFuture( + modelFactory.fromAny(executor.apply(input.asJavaObject()))); + } + + @Override + public boolean accept(Class clazz) { + return AbstractCallAIChatModelTask.class.isAssignableFrom(clazz); + } +} diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelExecutor.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelExecutor.java new file mode 100644 index 00000000..4d40a591 --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelExecutor.java @@ -0,0 +1,21 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.executors.ai; + +import java.util.function.UnaryOperator; + +public interface AIChatModelExecutor extends UnaryOperator {} diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelTaskExecutorFactory.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelTaskExecutorFactory.java new file mode 100644 index 00000000..c9f78360 --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/AIChatModelTaskExecutorFactory.java @@ -0,0 +1,21 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.executors.ai; + +import io.serverlessworkflow.impl.executors.DefaultTaskExecutorFactory; + +public class AIChatModelTaskExecutorFactory extends DefaultTaskExecutorFactory {} diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAIChatModelExecutor.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAIChatModelExecutor.java new file mode 100644 index 00000000..0eb51c0c --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAIChatModelExecutor.java @@ -0,0 +1,151 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.executors.ai; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.TokenUsage; +import io.serverlessworkflow.api.types.ai.CallAIChatModel; +import io.serverlessworkflow.impl.services.ChatModelService; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class CallAIChatModelExecutor implements AIChatModelExecutor { + + private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{\\s*(.+?)\\s*\\}\\}"); + + private final CallAIChatModel callAIChatModel; + + public CallAIChatModelExecutor(CallAIChatModel callAIChatModel) { + this.callAIChatModel = callAIChatModel; + } + + @Override + public Object apply(Object javaObject) { + validate(callAIChatModel, javaObject); + + ChatModel chatModel = createChatModel(callAIChatModel); + Map substitutions = (Map) javaObject; + + List messages = new ArrayList<>(); + + if (callAIChatModel.getChatModelRequest().getSystemMessages() != null) { + for (String systemMessage : callAIChatModel.getChatModelRequest().getSystemMessages()) { + String fixedUserMessage = replaceVariables(systemMessage, substitutions); + messages.add(new SystemMessage(fixedUserMessage)); + } + } + + if (callAIChatModel.getChatModelRequest().getUserMessages() != null) { + for (String userMessage : callAIChatModel.getChatModelRequest().getUserMessages()) { + String fixedUserMessage = replaceVariables(userMessage, substitutions); + messages.add(new UserMessage(fixedUserMessage)); + } + } + + return prepareResponse(chatModel.chat(messages), javaObject); + } + + private ChatModel createChatModel(CallAIChatModel callAIChatModel) { + ChatModelService chatModelService = getAvailableModel(); + if (chatModelService != null) { + return chatModelService.getChatModel(callAIChatModel.getPreferences()); + } + throw new IllegalStateException( + "No LLM models found. Please ensure that you have the required dependencies in your classpath."); + } + + private String replaceVariables(String template, Map substitutions) { + Set variables = extractVariables(template); + for (Map.Entry entry : substitutions.entrySet()) { + String variable = entry.getKey(); + Object value = entry.getValue(); + if (value != null && variables.contains(variable)) { + template = template.replace("{{" + variable + "}}", value.toString()); + } + } + return template; + } + + private ChatModelService getAvailableModel() { + ServiceLoader loader = ServiceLoader.load(ChatModelService.class); + + for (ChatModelService service : loader) { + return service; + } + + throw new IllegalStateException( + "No LLM models found. Please ensure that you have the required dependencies in your classpath."); + } + + private static Set extractVariables(String template) { + Set variables = new HashSet<>(); + Matcher matcher = VARIABLE_PATTERN.matcher(template); + while (matcher.find()) { + variables.add(matcher.group(1)); + } + return variables; + } + + private void validate(CallAIChatModel callAIChatModel, Object javaObject) { + // TODO + } + + protected Map prepareResponse(ChatResponse response, Object javaObject) { + String id = response.id(); + String modelName = response.modelName(); + TokenUsage tokenUsage = response.tokenUsage(); + FinishReason finishReason = response.finishReason(); + AiMessage aiMessage = response.aiMessage(); + + Map responseMap = (Map) javaObject; + if (response.id() != null) { + responseMap.put("id", id); + } + + if (modelName != null) { + responseMap.put("modelName", modelName); + } + + if (tokenUsage != null) { + responseMap.put("tokenUsage.inputTokenCount", tokenUsage.inputTokenCount()); + responseMap.put("tokenUsage.outputTokenCount", tokenUsage.outputTokenCount()); + responseMap.put("tokenUsage.totalTokenCount", tokenUsage.totalTokenCount()); + } + + if (finishReason != null) { + responseMap.put("finishReason", finishReason.name()); + } + + if (aiMessage != null) { + responseMap.put("result", aiMessage.text()); + } + + return responseMap; + } +} diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAILangChainChatModelExecutor.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAILangChainChatModelExecutor.java new file mode 100644 index 00000000..9ed9e760 --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/executors/ai/CallAILangChainChatModelExecutor.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.executors.ai; + +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.V; +import io.serverlessworkflow.api.types.ai.CallAILangChainChatModel; +import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class CallAILangChainChatModelExecutor implements AIChatModelExecutor { + + private final CallAILangChainChatModel callAIChatModel; + + public CallAILangChainChatModelExecutor(CallAILangChainChatModel callAIChatModel) { + this.callAIChatModel = callAIChatModel; + } + + @Override + public Object apply(Object javaObject) { + ChatModel chatModel = callAIChatModel.getChatModel(); + Class chatModelRequest = callAIChatModel.getChatModelRequest(); + Map substitutions = (Map) javaObject; + validate(chatModel, chatModelRequest, substitutions); + + Method method = getMethod(chatModelRequest, callAIChatModel.getMethodName()); + List resolvedParameters = resolvedParameters(method, substitutions); + + var aiServices = AiServices.builder(chatModelRequest).chatModel(chatModel).build(); + try { + Object[] args = new Object[resolvedParameters.size()]; + for (int i = 0; i < resolvedParameters.size(); i++) { + String paramName = resolvedParameters.get(i); + args[i] = substitutions.get(paramName); + } + + Object response = method.invoke(aiServices, args); + substitutions.put("result", response); + } catch (Exception e) { + throw new RuntimeException("Error invoking chat model method", e); + } + return substitutions; + } + + private void validate( + ChatModel chatModel, Class chatModelRequest, Map substitutions) {} + + private Method getMethod(Class chatModelRequest, String methodName) { + for (Method method : chatModelRequest.getMethods()) { + if (method.getName().equals(methodName)) { + return method; + } + } + throw new IllegalArgumentException( + "Method " + methodName + " not found in class " + chatModelRequest.getName()); + } + + private List resolvedParameters(Method method, Map substitutions) { + List resolvedParameters = new ArrayList<>(); + for (Parameter parameter : method.getParameters()) { + if (parameter.getAnnotation(V.class) != null) { + V v = parameter.getAnnotation(V.class); + String paramName = v.value(); + if (substitutions.containsKey(paramName)) { + resolvedParameters.add(paramName); + } else { + throw new IllegalArgumentException("Missing substitution for parameter: " + paramName); + } + } + } + return resolvedParameters; + } +} diff --git a/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/services/ChatModelService.java b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/services/ChatModelService.java new file mode 100644 index 00000000..5c4d8214 --- /dev/null +++ b/experimental/ai/impl/src/main/java/io/serverlessworkflow/impl/services/ChatModelService.java @@ -0,0 +1,24 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.services; + +import dev.langchain4j.model.chat.ChatModel; +import io.serverlessworkflow.api.types.ai.CallAIChatModel; + +public interface ChatModelService { + ChatModel getChatModel(CallAIChatModel.ChatModelPreferences chatModelPreferences); +} diff --git a/experimental/ai/impl/src/main/resources/META-INF/services/io.serverlessworkflow.impl.executors.CallableTask b/experimental/ai/impl/src/main/resources/META-INF/services/io.serverlessworkflow.impl.executors.CallableTask new file mode 100644 index 00000000..680fce77 --- /dev/null +++ b/experimental/ai/impl/src/main/resources/META-INF/services/io.serverlessworkflow.impl.executors.CallableTask @@ -0,0 +1 @@ +io.serverlessworkflow.impl.executors.ai.AIChatModelCallExecutor \ No newline at end of file diff --git a/experimental/ai/models/openai/pom.xml b/experimental/ai/models/openai/pom.xml new file mode 100644 index 00000000..42135bfd --- /dev/null +++ b/experimental/ai/models/openai/pom.xml @@ -0,0 +1,27 @@ + + + 4.0.0 + + io.serverlessworkflow + serverlessworkflow-experimental-ai-models-parent + 8.0.0-SNAPSHOT + + serverlessworkflow-experimental-ai-models-openai + ServelessWorkflow:: Experimental:: AI:: Models:: OpenAI + + + + dev.langchain4j + langchain4j-open-ai + 1.1.0 + + + io.serverlessworkflow + serverlessworkflow-experimental-ai-impl + 8.0.0-SNAPSHOT + compile + + + \ No newline at end of file diff --git a/experimental/ai/models/openai/src/main/java/io/serverlessworkflow/impl/services/openai/OpenAIChatModelService.java b/experimental/ai/models/openai/src/main/java/io/serverlessworkflow/impl/services/openai/OpenAIChatModelService.java new file mode 100644 index 00000000..b05eef89 --- /dev/null +++ b/experimental/ai/models/openai/src/main/java/io/serverlessworkflow/impl/services/openai/OpenAIChatModelService.java @@ -0,0 +1,96 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.impl.services.openai; + +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import io.serverlessworkflow.api.types.ai.CallAIChatModel; +import io.serverlessworkflow.impl.services.ChatModelService; + +public class OpenAIChatModelService implements ChatModelService { + + @Override + public ChatModel getChatModel(CallAIChatModel.ChatModelPreferences chatModelPreferences) { + OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder(); + if (chatModelPreferences.getApiKey() != null) { + builder.apiKey(chatModelPreferences.getApiKey()); + } + if (chatModelPreferences.getModelName() != null) { + builder.modelName(chatModelPreferences.getModelName()); + } + if (chatModelPreferences.getBaseUrl() != null) { + builder.baseUrl(chatModelPreferences.getBaseUrl()); + } + if (chatModelPreferences.getMaxTokens() != null) { + builder.maxTokens(chatModelPreferences.getMaxTokens()); + } + if (chatModelPreferences.getTemperature() != null) { + builder.temperature(chatModelPreferences.getTemperature()); + } + if (chatModelPreferences.getTopP() != null) { + builder.topP(chatModelPreferences.getTopP()); + } + if (chatModelPreferences.getResponseFormat() != null) { + builder.responseFormat(chatModelPreferences.getResponseFormat()); + } + if (chatModelPreferences.getMaxRetries() != null) { + builder.maxRetries(chatModelPreferences.getMaxRetries()); + } + if (chatModelPreferences.getTimeout() != null) { + builder.timeout(chatModelPreferences.getTimeout()); + } + if (chatModelPreferences.getLogRequests() != null) { + builder.logRequests(chatModelPreferences.getLogRequests()); + } + if (chatModelPreferences.getLogResponses() != null) { + builder.logResponses(chatModelPreferences.getLogResponses()); + } + if (chatModelPreferences.getResponseFormat() != null) { + builder.responseFormat(chatModelPreferences.getResponseFormat()); + } + + if (chatModelPreferences.getMaxCompletionTokens() != null) { + builder.maxCompletionTokens(chatModelPreferences.getMaxCompletionTokens()); + } + + if (chatModelPreferences.getPresencePenalty() != null) { + builder.presencePenalty(chatModelPreferences.getPresencePenalty()); + } + + if (chatModelPreferences.getFrequencyPenalty() != null) { + builder.frequencyPenalty(chatModelPreferences.getFrequencyPenalty()); + } + + if (chatModelPreferences.getStrictJsonSchema() != null) { + builder.strictJsonSchema(chatModelPreferences.getStrictJsonSchema()); + } + + if (chatModelPreferences.getSeed() != null) { + builder.seed(chatModelPreferences.getSeed()); + } + + if (chatModelPreferences.getUser() != null) { + builder.user(chatModelPreferences.getUser()); + } + + if (chatModelPreferences.getProjectId() != null) { + builder.projectId(chatModelPreferences.getProjectId()); + } + + return builder.build(); + } +} diff --git a/experimental/ai/models/openai/src/main/resources/META-INF/services/io.serverlessworkflow.impl.services.ChatModelService b/experimental/ai/models/openai/src/main/resources/META-INF/services/io.serverlessworkflow.impl.services.ChatModelService new file mode 100644 index 00000000..7f1d56b1 --- /dev/null +++ b/experimental/ai/models/openai/src/main/resources/META-INF/services/io.serverlessworkflow.impl.services.ChatModelService @@ -0,0 +1 @@ +io.serverlessworkflow.impl.services.openai.OpenAIChatModelService \ No newline at end of file diff --git a/experimental/ai/models/pom.xml b/experimental/ai/models/pom.xml new file mode 100644 index 00000000..b31fc70d --- /dev/null +++ b/experimental/ai/models/pom.xml @@ -0,0 +1,19 @@ + + + 4.0.0 + + io.serverlessworkflow + serverlessworkflow-experimental-ai-parent + 8.0.0-SNAPSHOT + + serverlessworkflow-experimental-ai-models-parent + ServelessWorkflow:: Experimental:: AI:: Models:: Parent + pom + + openai + + + + \ No newline at end of file diff --git a/experimental/ai/pom.xml b/experimental/ai/pom.xml new file mode 100644 index 00000000..2999c3a7 --- /dev/null +++ b/experimental/ai/pom.xml @@ -0,0 +1,19 @@ + + + 4.0.0 + + io.serverlessworkflow + serverlessworkflow-experimental + 8.0.0-SNAPSHOT + + serverlessworkflow-experimental-ai-parent + ServelessWorkflow:: Experimental:: AI:: Parent + pom + + impl + types + models + + \ No newline at end of file diff --git a/experimental/ai/types/pom.xml b/experimental/ai/types/pom.xml new file mode 100644 index 00000000..31f04713 --- /dev/null +++ b/experimental/ai/types/pom.xml @@ -0,0 +1,55 @@ + + + 4.0.0 + + io.serverlessworkflow + serverlessworkflow-experimental-ai-parent + 8.0.0-SNAPSHOT + + serverlessworkflow-experimental-ai-types + ServelessWorkflow:: Experimental:: AI:: Types + + + + io.serverlessworkflow + serverlessworkflow-experimental-types + + + dev.langchain4j + langchain4j + 1.1.0 + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.assertj + assertj-core + test + + + ch.qos.logback + logback-classic + test + + + io.serverlessworkflow + serverlessworkflow-experimental-types + + + + \ No newline at end of file diff --git a/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/AbstractCallAIChatModelTask.java b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/AbstractCallAIChatModelTask.java new file mode 100644 index 00000000..63bd40fb --- /dev/null +++ b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/AbstractCallAIChatModelTask.java @@ -0,0 +1,21 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.api.types.ai; + +import io.serverlessworkflow.api.types.TaskBase; + +public abstract class AbstractCallAIChatModelTask extends TaskBase {} diff --git a/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAIChatModel.java b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAIChatModel.java new file mode 100644 index 00000000..0c10371a --- /dev/null +++ b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAIChatModel.java @@ -0,0 +1,438 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.api.types.ai; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class CallAIChatModel extends AbstractCallAIChatModelTask { + + private ChatModelPreferences chatModelPreferences; + + private ChatModelRequest chatModelRequest; + + protected CallAIChatModel() {} + + public static Builder builder() { + return new Builder(); + } + + @Override + public String toString() { + return "CallAIChatModel{" + + "chatModelPreferences=" + + chatModelPreferences + + ", chatModelRequest=" + + chatModelRequest + + '}'; + } + + public ChatModelPreferences getPreferences() { + return chatModelPreferences; + } + + public ChatModelRequest getChatModelRequest() { + return chatModelRequest; + } + + public static class Builder { + + private Builder() {} + + private ChatModelPreferences chatModelPreferences; + private ChatModelRequest chatModelRequest; + + public Builder preferences(ChatModelPreferences chatModelPreferences) { + this.chatModelPreferences = chatModelPreferences; + return this; + } + + public Builder request(ChatModelRequest chatModelRequest) { + this.chatModelRequest = chatModelRequest; + return this; + } + + public CallAIChatModel build() { + CallAIChatModel callAIChatModel = new CallAIChatModel(); + callAIChatModel.chatModelPreferences = this.chatModelPreferences; + callAIChatModel.chatModelRequest = this.chatModelRequest; + return callAIChatModel; + } + } + + public static class ChatModelPreferences { + private String baseUrl; + private String apiKey; + private String organizationId; + private String projectId; + private String modelName; + private Double temperature; + private Double topP; + private Integer maxTokens; + private Integer maxCompletionTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private String responseFormat; + private Boolean strictJsonSchema; + private Integer seed; + private String user; + private Duration timeout; + private Integer maxRetries; + private Boolean logRequests; + private Boolean logResponses; + + private ChatModelPreferences() {} + + public static ChatModelPreferences.Builder builder() { + return new ChatModelPreferences.Builder(); + } + + public String getBaseUrl() { + return baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public String getOrganizationId() { + return organizationId; + } + + public String getProjectId() { + return projectId; + } + + public String getModelName() { + return modelName; + } + + public Double getTemperature() { + return temperature; + } + + public Double getTopP() { + return topP; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public Integer getMaxCompletionTokens() { + return maxCompletionTokens; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public String getResponseFormat() { + return responseFormat; + } + + public Boolean getStrictJsonSchema() { + return strictJsonSchema; + } + + public Integer getSeed() { + return seed; + } + + public String getUser() { + return user; + } + + public Duration getTimeout() { + return timeout; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + @Override + public String toString() { + return "Builder{" + + "baseUrl='" + + baseUrl + + '\'' + + ", apiKey='" + + apiKey + + '\'' + + ", organizationId='" + + organizationId + + '\'' + + ", projectId='" + + projectId + + '\'' + + ", modelName='" + + modelName + + '\'' + + ", temperature=" + + temperature + + ", topP=" + + topP + + ", maxTokens=" + + maxTokens + + ", maxCompletionTokens=" + + maxCompletionTokens + + ", presencePenalty=" + + presencePenalty + + ", frequencyPenalty=" + + frequencyPenalty + + ", responseFormat='" + + responseFormat + + '\'' + + ", strictJsonSchema=" + + strictJsonSchema + + ", seed=" + + seed + + ", user='" + + user + + '\'' + + ", timeout=" + + timeout + + ", maxRetries=" + + maxRetries + + ", logRequests=" + + logRequests + + ", logResponses=" + + logResponses + + '}'; + } + + public static class Builder { + private String baseUrl; + private String apiKey; + private String organizationId; + private String projectId; + private String modelName; + private Double temperature; + private Double topP; + private Integer maxTokens; + private Integer maxCompletionTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private String responseFormat; + private Boolean strictJsonSchema; + private Integer seed; + private String user; + private Duration timeout; + private Integer maxRetries; + private Boolean logRequests; + private Boolean logResponses; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder organizationId(String organizationId) { + this.organizationId = organizationId; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder maxCompletionTokens(Integer maxCompletionTokens) { + this.maxCompletionTokens = maxCompletionTokens; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder responseFormat(String responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder strictJsonSchema(Boolean strictJsonSchema) { + this.strictJsonSchema = strictJsonSchema; + return this; + } + + public Builder seed(Integer seed) { + this.seed = seed; + return this; + } + + public Builder user(String user) { + this.user = user; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder logRequests(Boolean logRequests) { + this.logRequests = logRequests; + return this; + } + + public Builder logResponses(Boolean logResponses) { + this.logResponses = logResponses; + return this; + } + + public ChatModelPreferences build() { + ChatModelPreferences preferences = new ChatModelPreferences(); + preferences.baseUrl = this.baseUrl; + preferences.apiKey = this.apiKey; + preferences.organizationId = this.organizationId; + preferences.projectId = this.projectId; + preferences.modelName = this.modelName; + preferences.temperature = this.temperature; + preferences.topP = this.topP; + preferences.maxTokens = this.maxTokens; + preferences.maxCompletionTokens = this.maxCompletionTokens; + preferences.presencePenalty = this.presencePenalty; + preferences.frequencyPenalty = this.frequencyPenalty; + preferences.responseFormat = this.responseFormat; + preferences.strictJsonSchema = this.strictJsonSchema; + preferences.seed = this.seed; + preferences.user = this.user; + preferences.timeout = this.timeout; + preferences.maxRetries = this.maxRetries; + preferences.logRequests = this.logRequests; + preferences.logResponses = this.logResponses; + return preferences; + } + } + } + + public static class ChatModelRequest { + + private List userMessages; + private List systemMessages; + + private ChatModelRequest() {} + + public List getUserMessages() { + return userMessages; + } + + public List getSystemMessages() { + return systemMessages; + } + + public static ChatModelRequest.Builder builder() { + return new ChatModelRequest.Builder(); + } + + @Override + public String toString() { + return "ChatModelRequest{" + + "userMessages=" + + String.join(",", userMessages) + + ", systemMessages=" + + String.join(",", systemMessages) + + '}'; + } + + public static class Builder { + private List userMessages = new ArrayList<>(); + private List systemMessages = new ArrayList<>(); + + private Builder() {} + + public Builder userMessage(String userMessage) { + this.userMessages.add(userMessage); + return this; + } + + public Builder userMessages(Collection userMessages) { + this.userMessages.addAll(userMessages); + return this; + } + + public Builder systemMessage(String systemMessage) { + this.systemMessages.add(systemMessage); + return this; + } + + public Builder systemMessages(Collection systemMessages) { + this.systemMessages.addAll(systemMessages); + return this; + } + + public ChatModelRequest build() { + ChatModelRequest request = new ChatModelRequest(); + request.userMessages = this.userMessages; + request.systemMessages = this.systemMessages; + return request; + } + } + } +} diff --git a/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAILangChainChatModel.java b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAILangChainChatModel.java new file mode 100644 index 00000000..da24b5c8 --- /dev/null +++ b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallAILangChainChatModel.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.api.types.ai; + +import dev.langchain4j.model.chat.ChatModel; + +public class CallAILangChainChatModel extends AbstractCallAIChatModelTask { + + private final ChatModel chatModel; + private final Class chatModelRequest; + + private final String methodName; + + public CallAILangChainChatModel( + ChatModel chatModel, Class chatModelRequest, String methodName) { + this.chatModel = chatModel; + this.chatModelRequest = chatModelRequest; + this.methodName = methodName; + } + + public ChatModel getChatModel() { + return chatModel; + } + + public Class getChatModelRequest() { + return chatModelRequest; + } + + public String getMethodName() { + return methodName; + } +} diff --git a/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallTaskAIChatModel.java b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallTaskAIChatModel.java new file mode 100644 index 00000000..4e2bd241 --- /dev/null +++ b/experimental/ai/types/src/main/java/io/serverlessworkflow/api/types/ai/CallTaskAIChatModel.java @@ -0,0 +1,37 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.serverlessworkflow.api.types.ai; + +import io.serverlessworkflow.api.types.CallTask; + +public class CallTaskAIChatModel extends CallTask { + + private AbstractCallAIChatModelTask callAIChatModel; + + public CallTaskAIChatModel(AbstractCallAIChatModelTask callAIChatModel) { + this.callAIChatModel = callAIChatModel; + } + + public AbstractCallAIChatModelTask getCallAIChatModel() { + return callAIChatModel; + } + + @Override + public Object get() { + return callAIChatModel != null ? callAIChatModel : super.get(); + } +} diff --git a/experimental/pom.xml b/experimental/pom.xml index 3312207e..1c5efd8d 100644 --- a/experimental/pom.xml +++ b/experimental/pom.xml @@ -35,10 +35,16 @@ serverlessworkflow-fluent-func ${project.version} + + io.serverlessworkflow + serverlessworkflow-experimental-ai-types + ${project.version} + types lambda + ai - \ No newline at end of file +