
Commit b1fe7df

Prevent Jlama inference from blocking the Vert.x event loop

1 parent c63b96e

File tree

3 files changed, +73 -17 lines
io/quarkiverse/langchain4j/runtime/VertxUtil.java

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+package io.quarkiverse.langchain4j.runtime;
+
+import java.util.concurrent.Callable;
+
+import io.smallrye.common.vertx.VertxContext;
+import io.smallrye.mutiny.infrastructure.Infrastructure;
+import io.vertx.core.Context;
+
+public class VertxUtil {
+
+    public static void runOutEventLoop(Runnable runnable) {
+        if (Context.isOnEventLoopThread()) {
+            Context executionContext = VertxContext.getOrCreateDuplicatedContext();
+            if (executionContext != null) {
+                executionContext.executeBlocking(new Callable<Object>() {
+                    @Override
+                    public Object call() {
+                        runnable.run();
+                        return null;
+                    }
+                });
+            } else {
+                Infrastructure.getDefaultWorkerPool().execute(runnable);
+            }
+        } else {
+            runnable.run();
+        }
+    }
+}
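
The helper's dispatch logic: on a Vert.x event-loop thread the task is offloaded via executeBlocking on a duplicated context (or, if no context is available, handed to Mutiny's default worker pool); on any other thread it simply runs inline. A minimal caller sketch, with the surrounding class and method names invented for illustration:

import io.quarkiverse.langchain4j.runtime.VertxUtil;

public class BlockingTaskExample {

    public void handleRequest() {
        // On an event-loop thread this dispatches to a worker thread;
        // on any other thread it runs synchronously on the caller.
        VertxUtil.runOutEventLoop(this::runModelInference);
    }

    private void runModelInference() {
        // Placeholder for long-running, CPU-bound work such as local LLM inference.
    }
}

Requesting a duplicated context rather than reusing the root context follows the usual Quarkus/SmallRye convention of keeping context-local data isolated per unit of work.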

model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaStreamingChatModel.java

Lines changed: 36 additions & 16 deletions
@@ -1,6 +1,7 @@
 package io.quarkiverse.langchain4j.jlama;
 
 import static io.quarkiverse.langchain4j.jlama.JlamaModel.toFinishReason;
+import static io.quarkiverse.langchain4j.runtime.VertxUtil.runOutEventLoop;
 
 import java.nio.file.Path;
 import java.util.List;
@@ -10,6 +11,7 @@
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
 import com.github.tjake.jlama.safetensors.DType;
+import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 
 import dev.langchain4j.data.message.AiMessage;
@@ -32,17 +34,21 @@ public JlamaStreamingChatModel(JlamaStreamingChatModelBuilder builder) {
                 .withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);
 
         JlamaModel.Loader loader = jlamaModel.loader();
-        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
+        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime) {
             loader = loader.quantized();
+        }
 
-        if (builder.workingQuantizedType != null)
+        if (builder.workingQuantizedType != null) {
             loader = loader.workingQuantizationType(builder.workingQuantizedType);
+        }
 
-        if (builder.threadCount != null)
+        if (builder.threadCount != null) {
             loader = loader.threadCount(builder.threadCount);
+        }
 
-        if (builder.workingDirectory != null)
+        if (builder.workingDirectory != null) {
             loader = loader.workingDirectory(builder.workingDirectory);
+        }
 
         this.model = loader.load();
         this.temperature = builder.temperature == null ? 0.7f : builder.temperature;
@@ -55,21 +61,18 @@ public static JlamaStreamingChatModelBuilder builder() {
 
     @Override
     public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
-        if (model.promptSupport().isEmpty())
-            throw new UnsupportedOperationException("This model does not support chat generation");
-
-        PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();
-        for (ChatMessage message : messages) {
-            switch (message.type()) {
-                case SYSTEM -> promptBuilder.addSystemMessage(message.text());
-                case USER -> promptBuilder.addUserMessage(message.text());
-                case AI -> promptBuilder.addAssistantMessage(message.text());
-                default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
+        PromptContext promptContext = createPromptContext(messages);
+        runOutEventLoop(new Runnable() {
+            @Override
+            public void run() {
+                internalGenerate(handler, promptContext);
             }
-        }
+        });
+    }
 
+    private void internalGenerate(StreamingResponseHandler<AiMessage> handler, PromptContext promptContext) {
         try {
-            Generator.Response r = model.generate(id, promptBuilder.build(), temperature, maxTokens, (token, time) -> {
+            Generator.Response r = model.generate(id, promptContext, temperature, maxTokens, (token, time) -> {
                 handler.onNext(token);
             });
 
@@ -80,6 +83,23 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMess
         }
     }
 
+    private PromptContext createPromptContext(List<ChatMessage> messages) {
+        if (model.promptSupport().isEmpty()) {
+            throw new UnsupportedOperationException("This model does not support chat generation");
+        }
+
+        PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();
+        for (ChatMessage message : messages) {
+            switch (message.type()) {
+                case SYSTEM -> promptBuilder.addSystemMessage(message.text());
+                case USER -> promptBuilder.addUserMessage(message.text());
+                case AI -> promptBuilder.addAssistantMessage(message.text());
+                default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
+            }
+        }
+        return promptBuilder.build();
+    }
+
     @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
     public static class JlamaStreamingChatModelBuilder {

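With this refactoring, generate() builds the PromptContext up front on the calling thread and defers only the actual token generation to runOutEventLoop, so when generate() is invoked from an event-loop thread the handler's callbacks now fire on a worker thread instead of the event loop. Note that because the prompt is built eagerly, the UnsupportedOperationException and IllegalArgumentException above still surface synchronously to the caller rather than through handler.onError. A sketch of the consumer side, using langchain4j's StreamingResponseHandler (the handler body is a made-up example, not part of the commit):

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;

public class HandlerExample {

    static StreamingResponseHandler<AiMessage> printingHandler() {
        return new StreamingResponseHandler<>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // invoked once per generated token
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println(); // final response, including the finish reason
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        };
    }
}
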
model-providers/llama3-java/runtime/src/main/java/io/quarkiverse/langchain4j/llama3/Llama3StreamingChatModel.java

Lines changed: 8 additions & 1 deletion
@@ -4,6 +4,7 @@
 import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
 import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
 import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;
+import static io.quarkiverse.langchain4j.runtime.VertxUtil.runOutEventLoop;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -79,7 +80,13 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMess
         );
         Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
                 options.seed());
-        runInference(model, sampler, options, llama3Messages, handler);
+
+        runOutEventLoop(new Runnable() {
+            @Override
+            public void run() {
+                runInference(model, sampler, options, llama3Messages, handler);
+            }
+        });
     }
 
     private void runInference(Llama model, Sampler sampler, Llama3.Options options,
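
The same dispatch pattern is applied here with an explicit anonymous Runnable; a lambda would be an equivalent, shorter form (a stylistic aside, not part of the commit):

runOutEventLoop(() -> runInference(model, sampler, options, llama3Messages, handler));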
