16
16
17
17
package io .serverlessworkflow .impl .executors .ai ;
18
18
19
- import dev .langchain4j .data .message .AiMessage ;
20
- import dev .langchain4j .data .message .ChatMessage ;
21
- import dev .langchain4j .data .message .SystemMessage ;
22
- import dev .langchain4j .data .message .UserMessage ;
23
- import dev .langchain4j .model .chat .ChatModel ;
24
- import dev .langchain4j .model .chat .response .ChatResponse ;
25
- import dev .langchain4j .model .output .FinishReason ;
26
- import dev .langchain4j .model .output .TokenUsage ;
27
19
import io .serverlessworkflow .ai .api .types .CallAILangChainChatModel ;
28
20
import io .serverlessworkflow .api .types .TaskBase ;
29
21
import io .serverlessworkflow .api .types .ai .CallAIChatModel ;
34
26
import io .serverlessworkflow .impl .WorkflowModelFactory ;
35
27
import io .serverlessworkflow .impl .executors .CallableTask ;
36
28
import io .serverlessworkflow .impl .resources .ResourceLoader ;
37
- import io .serverlessworkflow .impl .services .ChatModelService ;
38
- import java .util .ArrayList ;
39
- import java .util .HashSet ;
40
- import java .util .List ;
41
- import java .util .Map ;
42
- import java .util .ServiceLoader ;
43
- import java .util .Set ;
44
29
import java .util .concurrent .CompletableFuture ;
45
- import java .util .regex .Matcher ;
46
- import java .util .regex .Pattern ;
47
30
48
31
public class AIChatModelCallExecutor implements CallableTask <CallAIChatModel > {
49
32
50
- private static final Pattern VARIABLE_PATTERN = Pattern .compile ("\\ {\\ {\\ s*(.+?)\\ s*\\ }\\ }" );
51
-
52
33
@ Override
53
34
public void init (CallAIChatModel task , WorkflowApplication application , ResourceLoader loader ) {}
54
35
@@ -58,12 +39,13 @@ public CompletableFuture<WorkflowModel> apply(
58
39
WorkflowModelFactory modelFactory = workflowContext .definition ().application ().modelFactory ();
59
40
if (taskContext .task () instanceof CallAILangChainChatModel callAILangChainChatModel ) {
60
41
return CompletableFuture .completedFuture (
61
- modelFactory .fromAny (doCall ( callAILangChainChatModel , input . asJavaObject ())));
62
- }
63
-
64
- if (taskContext .task () instanceof CallAIChatModel callAIChatModel ) {
42
+ modelFactory .fromAny (
43
+ new CallAILangChainChatModelExecutor ()
44
+ . apply ( callAILangChainChatModel , input . asJavaObject ())));
45
+ } else if (taskContext .task () instanceof CallAIChatModel callAIChatModel ) {
65
46
return CompletableFuture .completedFuture (
66
- modelFactory .fromAny (doCall (callAIChatModel , input .asJavaObject ())));
47
+ modelFactory .fromAny (
48
+ new CallAIChatModelExecutor ().apply (callAIChatModel , input .asJavaObject ())));
67
49
}
68
50
throw new IllegalArgumentException (
69
51
"AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: "
@@ -74,112 +56,4 @@ public CompletableFuture<WorkflowModel> apply(
74
56
public boolean accept (Class <? extends TaskBase > clazz ) {
75
57
return CallAIChatModel .class .isAssignableFrom (clazz );
76
58
}
77
-
78
- private Object doCall (CallAILangChainChatModel callAIChatModel , Object javaObject ) {
79
- ChatModel chatModel = callAIChatModel .getChatModel ();
80
- Class <?> chatModelRequest = callAIChatModel .getChatModelRequest ();
81
- }
82
-
83
- private Object doCall (CallAIChatModel callAIChatModel , Object javaObject ) {
84
- validate (callAIChatModel , javaObject );
85
- ChatModel chatModel = createChatModel (callAIChatModel );
86
- Map <String , Object > substitutions = (Map <String , Object >) javaObject ;
87
-
88
- List <ChatMessage > messages = new ArrayList <>();
89
-
90
- if (callAIChatModel .getChatModelRequest ().getSystemMessages () != null ) {
91
- for (String systemMessage : callAIChatModel .getChatModelRequest ().getSystemMessages ()) {
92
- String fixedUserMessage = replaceVariables (systemMessage , substitutions );
93
- messages .add (new SystemMessage (fixedUserMessage ));
94
- }
95
- }
96
-
97
- if (callAIChatModel .getChatModelRequest ().getUserMessages () != null ) {
98
- for (String userMessage : callAIChatModel .getChatModelRequest ().getUserMessages ()) {
99
- String fixedUserMessage = replaceVariables (userMessage , substitutions );
100
- messages .add (new UserMessage (fixedUserMessage ));
101
- }
102
- }
103
-
104
- return prepareResponse (chatModel .chat (messages ), javaObject );
105
- }
106
-
107
- private String replaceVariables (String template , Map <String , Object > substitutions ) {
108
- Set <String > variables = extractVariables (template );
109
- for (Map .Entry <String , Object > entry : substitutions .entrySet ()) {
110
- String variable = entry .getKey ();
111
- Object value = entry .getValue ();
112
- if (value != null && variables .contains (variable )) {
113
- template = template .replace ("{{" + variable + "}}" , value .toString ());
114
- }
115
- }
116
- return template ;
117
- }
118
-
119
- private void validate (CallAIChatModel callAIChatModel , Object javaObject ) {
120
- // TODO
121
- }
122
-
123
- private ChatModel createChatModel (CallAIChatModel callAIChatModel ) {
124
- ChatModelService chatModelService = getAvailableModel ();
125
- if (chatModelService != null ) {
126
- return chatModelService .getChatModel (callAIChatModel .getPreferences ());
127
- }
128
- throw new IllegalStateException (
129
- "No LLM models found. Please ensure that you have the required dependencies in your classpath." );
130
- }
131
-
132
- private ChatModelService getAvailableModel () {
133
- ServiceLoader <ChatModelService > loader = ServiceLoader .load (ChatModelService .class );
134
-
135
- for (ChatModelService service : loader ) {
136
- return service ;
137
- }
138
-
139
- throw new IllegalStateException (
140
- "No LLM models found. Please ensure that you have the required dependencies in your classpath." );
141
- }
142
-
143
- private Map <String , Object > prepareResponse (ChatResponse response , Object javaObject ) {
144
-
145
- String id = response .id ();
146
- String modelName = response .modelName ();
147
- TokenUsage tokenUsage = response .tokenUsage ();
148
- FinishReason finishReason = response .finishReason ();
149
- AiMessage aiMessage = response .aiMessage ();
150
-
151
- Map <String , Object > responseMap = (Map <String , Object >) javaObject ;
152
- if (response .id () != null ) {
153
- responseMap .put ("id" , id );
154
- }
155
-
156
- if (modelName != null ) {
157
- responseMap .put ("modelName" , modelName );
158
- }
159
-
160
- if (tokenUsage != null ) {
161
- responseMap .put ("tokenUsage.inputTokenCount" , tokenUsage .inputTokenCount ());
162
- responseMap .put ("tokenUsage.outputTokenCount" , tokenUsage .outputTokenCount ());
163
- responseMap .put ("tokenUsage.totalTokenCount" , tokenUsage .totalTokenCount ());
164
- }
165
-
166
- if (finishReason != null ) {
167
- responseMap .put ("finishReason" , finishReason .name ());
168
- }
169
-
170
- if (aiMessage != null ) {
171
- responseMap .put ("text" , aiMessage .text ());
172
- }
173
-
174
- return responseMap ;
175
- }
176
-
177
- private static Set <String > extractVariables (String template ) {
178
- Set <String > variables = new HashSet <>();
179
- Matcher matcher = VARIABLE_PATTERN .matcher (template );
180
- while (matcher .find ()) {
181
- variables .add (matcher .group (1 ));
182
- }
183
- return variables ;
184
- }
185
59
}
0 commit comments