Skip to content

Commit 0690a53

Browse files
committed
Fix Blocking Memory Store Usage in Streamed Mode
This commit fixes the use of a blocking memory store with streamed responses. * Captures, at invocation time, whether the caller is running on a worker thread. * When the caller is on a worker thread, switches to a worker thread for every emission and completion event. * Relies on executeBlocking to propagate the context automatically where possible. Note: * A blocking memory store cannot be used when the AI service is invoked from the event loop; the caller must now be on a worker thread.
1 parent 5ec8e05 commit 0690a53

File tree

7 files changed

+379
-48
lines changed

7 files changed

+379
-48
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,13 +1284,15 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
12841284
String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method);
12851285

12861286
// Detect if tools execution may block the caller thread.
1287-
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);
1287+
boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools,
1288+
methodToolClassNames);
12881289

12891290
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
12901291
userMessageInfo, memoryIdParamPosition, requiresModeration,
12911292
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
1292-
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
1293-
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
1293+
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames,
1294+
switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails, accumulatorClassName,
1295+
responseAugmenterClassName);
12941296
}
12951297

12961298
private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package io.quarkiverse.langchain4j.test.streaming;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.List;
6+
import java.util.UUID;
7+
import java.util.concurrent.CountDownLatch;
8+
import java.util.concurrent.TimeUnit;
9+
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.inject.Inject;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.BeforeEach;
16+
import org.junit.jupiter.api.RepeatedTest;
17+
import org.junit.jupiter.api.Test;
18+
import org.junit.jupiter.api.extension.RegisterExtension;
19+
20+
import dev.langchain4j.service.MemoryId;
21+
import dev.langchain4j.service.UserMessage;
22+
import io.quarkiverse.langchain4j.RegisterAiService;
23+
import io.quarkus.arc.Arc;
24+
import io.quarkus.test.QuarkusUnitTest;
25+
import io.smallrye.common.vertx.VertxContext;
26+
import io.smallrye.mutiny.Multi;
27+
import io.vertx.core.Context;
28+
import io.vertx.core.Vertx;
29+
30+
public class BlockingMemoryStoreOnStreamedResponseTest {
31+
32+
@RegisterExtension
33+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
34+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
35+
.addClasses(StreamTestUtils.class));
36+
37+
@Inject
38+
MyAiService service;
39+
40+
@RepeatedTest(100) // Verify that the order is preserved.
41+
@ActivateRequestContext
42+
void testFromWorkerThread() {
43+
// We are on a worker thread.
44+
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
45+
// We cannot guarantee the order, as we do not have a context.
46+
assertThat(list).containsExactly("Hi!", " ", "World!");
47+
48+
list = service.hi("123", "Second message").collect().asList().await().indefinitely();
49+
assertThat(list).containsExactly("OK!");
50+
}
51+
52+
@BeforeEach
53+
void cleanup() {
54+
StreamTestUtils.FakeMemoryStore.DC_DATA = null;
55+
}
56+
57+
@RepeatedTest(10)
58+
void testFromDuplicatedContextThread() throws InterruptedException {
59+
Context context = VertxContext.getOrCreateDuplicatedContext(vertx);
60+
CountDownLatch latch = new CountDownLatch(1);
61+
context.executeBlocking(v -> {
62+
try {
63+
Arc.container().requestContext().activate();
64+
var value = UUID.randomUUID().toString();
65+
StreamTestUtils.FakeMemoryStore.DC_DATA = value;
66+
Vertx.currentContext().putLocal("DC_DATA", value);
67+
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
68+
assertThat(list).containsExactly("Hi!", " ", "World!");
69+
Arc.container().requestContext().deactivate();
70+
71+
Arc.container().requestContext().activate();
72+
73+
list = service.hi("123", "Second message").collect().asList().await().indefinitely();
74+
assertThat(list).containsExactly("OK!");
75+
latch.countDown();
76+
77+
} finally {
78+
Arc.container().requestContext().deactivate();
79+
Vertx.currentContext().removeLocal("DC_DATA");
80+
81+
}
82+
}, false);
83+
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
84+
}
85+
86+
@Inject
87+
Vertx vertx;
88+
89+
@Test
90+
void testFromEventLoopThread() throws InterruptedException {
91+
CountDownLatch latch = new CountDownLatch(1);
92+
Context context = vertx.getOrCreateContext();
93+
context.runOnContext(v -> {
94+
Arc.container().requestContext().activate();
95+
try {
96+
service.hi("123", "Say hello").collect().asList()
97+
.subscribe().asCompletionStage();
98+
} catch (Exception e) {
99+
assertThat(e)
100+
.isNotNull()
101+
.hasMessageContaining("Expected to be able to block");
102+
} finally {
103+
Arc.container().requestContext().deactivate();
104+
latch.countDown();
105+
}
106+
});
107+
latch.await();
108+
}
109+
110+
@RegisterAiService(streamingChatLanguageModelSupplier = StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class)
111+
public interface MyAiService {
112+
113+
Multi<String> hi(@MemoryId String id, @UserMessage String query);
114+
115+
}
116+
117+
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package io.quarkiverse.langchain4j.test.streaming;
2+
3+
import java.util.List;
4+
import java.util.Map;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import java.util.concurrent.CopyOnWriteArrayList;
7+
import java.util.function.Supplier;
8+
9+
import dev.langchain4j.data.message.AiMessage;
10+
import dev.langchain4j.data.message.ChatMessage;
11+
import dev.langchain4j.data.message.UserMessage;
12+
import dev.langchain4j.memory.ChatMemory;
13+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
14+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
15+
import dev.langchain4j.model.StreamingResponseHandler;
16+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
17+
import dev.langchain4j.model.output.Response;
18+
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
19+
import io.quarkus.arc.Arc;
20+
import io.smallrye.mutiny.infrastructure.Infrastructure;
21+
import io.vertx.core.Vertx;
22+
23+
/**
24+
* Utility class for streaming tests.
25+
*/
26+
public class StreamTestUtils {
27+
28+
public static class FakeMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
29+
@Override
30+
public ChatMemoryProvider get() {
31+
return new ChatMemoryProvider() {
32+
@Override
33+
public ChatMemory get(Object memoryId) {
34+
return new MessageWindowChatMemory.Builder()
35+
.id(memoryId)
36+
.maxMessages(10)
37+
.chatMemoryStore(new FakeMemoryStore())
38+
.build();
39+
}
40+
};
41+
}
42+
}
43+
44+
public static class FakeStreamedChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
45+
46+
@Override
47+
public StreamingChatLanguageModel get() {
48+
return new FakeStreamedChatModel();
49+
}
50+
}
51+
52+
public static class FakeStreamedChatModel implements StreamingChatLanguageModel {
53+
54+
@Override
55+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
56+
Vertx vertx = Arc.container().select(Vertx.class).get();
57+
var ctxt = vertx.getOrCreateContext();
58+
59+
if (messages.size() > 1) {
60+
var last = (UserMessage) messages.get(messages.size() - 1);
61+
if (last.singleText().equalsIgnoreCase("Second message")) {
62+
if (messages.size() < 3) {
63+
ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory")));
64+
return;
65+
} else {
66+
ctxt.runOnContext(x -> {
67+
handler.onNext("OK!");
68+
handler.onComplete(Response.from(AiMessage.from("")));
69+
});
70+
return;
71+
}
72+
}
73+
}
74+
75+
ctxt.runOnContext(x1 -> {
76+
handler.onNext("Hi!");
77+
ctxt.runOnContext(x2 -> {
78+
handler.onNext(" ");
79+
ctxt.runOnContext(x3 -> {
80+
handler.onNext("World!");
81+
ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from(""))));
82+
});
83+
});
84+
});
85+
}
86+
}
87+
88+
public static class FakeMemoryStore implements ChatMemoryStore {
89+
90+
public static String DC_DATA;
91+
92+
private static final Map<Object, List<ChatMessage>> memories = new ConcurrentHashMap<>();
93+
94+
private void checkDuplicatedContext() {
95+
if (DC_DATA != null) {
96+
if (!DC_DATA.equals(Vertx.currentContext().getLocal("DC_DATA"))) {
97+
throw new AssertionError("Expected to be in the same context");
98+
}
99+
}
100+
}
101+
102+
@Override
103+
public List<ChatMessage> getMessages(Object memoryId) {
104+
if (!Infrastructure.canCallerThreadBeBlocked()) {
105+
throw new AssertionError("Expected to be able to block");
106+
}
107+
checkDuplicatedContext();
108+
return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>());
109+
}
110+
111+
@Override
112+
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
113+
if (!Infrastructure.canCallerThreadBeBlocked()) {
114+
throw new AssertionError("Expected to be able to block");
115+
}
116+
memories.put(memoryId, messages);
117+
}
118+
119+
@Override
120+
public void deleteMessages(Object memoryId) {
121+
if (!Infrastructure.canCallerThreadBeBlocked()) {
122+
throw new AssertionError("Expected to be able to block");
123+
}
124+
}
125+
}
126+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public final class AiServiceMethodCreateInfo {
5959
private OutputTokenAccumulator accumulator;
6060

6161
private final LazyValue<Integer> guardrailsMaxRetry;
62-
private final boolean switchToWorkerThread;
62+
private final boolean switchToWorkerThreadForToolExecution;
6363

6464
@RecordableConstructor
6565
public AiServiceMethodCreateInfo(String interfaceName, String methodName,
@@ -73,7 +73,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
7373
Optional<SpanInfo> spanInfo,
7474
ResponseSchemaInfo responseSchemaInfo,
7575
List<String> toolClassNames,
76-
boolean switchToWorkerThread,
76+
boolean switchToWorkerThreadForToolExecution,
7777
List<String> inputGuardrailsClassNames,
7878
List<String> outputGuardrailsClassNames,
7979
String outputTokenAccumulatorClassName,
@@ -107,7 +107,7 @@ public Integer get() {
107107
.orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
108108
}
109109
});
110-
this.switchToWorkerThread = switchToWorkerThread;
110+
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
111111
this.responseAugmenterClassName = responseAugmenterClassName;
112112
}
113113

@@ -237,8 +237,8 @@ public String getUserMessageTemplate() {
237237
return userMessageTemplateOpt.orElse("");
238238
}
239239

240-
public boolean isSwitchToWorkerThread() {
241-
return switchToWorkerThread;
240+
public boolean isSwitchToWorkerThreadForToolExecution() {
241+
return switchToWorkerThreadForToolExecution;
242242
}
243243

244244
public void setResponseAugmenter(Class<? extends AiResponseAugmenter<?>> augmenter) {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public Object implement(Input input) {
145145

146146
private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs,
147147
QuarkusAiServiceContext context, Audit audit) {
148+
boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread();
148149
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
149150
Optional<SystemMessage> systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs,
150151
context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList());
@@ -217,7 +218,7 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
217218
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
218219
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
219220
finalToolExecutors, ar.contents(), context, memoryId,
220-
methodCreateInfo.isSwitchToWorkerThread());
221+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
221222
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
222223
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
223224
memory, ar, methodCreateInfo.getUserMessageTemplate(),
@@ -268,7 +269,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
268269
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
269270
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
270271
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
271-
methodCreateInfo.isSwitchToWorkerThread());
272+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
272273
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
273274
new ResponseAugmenterParams(actualUserMessage,
274275
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
@@ -277,7 +278,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
277278

278279
return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
279280
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
280-
methodCreateInfo.isSwitchToWorkerThread())
281+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread)
281282
.plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo))
282283
.map(chunk -> {
283284
OutputGuardrailResult result;
@@ -786,19 +787,22 @@ private static class TokenStreamMulti extends AbstractMulti<String> implements M
786787
private final List<Content> contents;
787788
private final QuarkusAiServiceContext context;
788789
private final Object memoryId;
789-
private final boolean mustSwitchToWorkerThread;
790+
private final boolean switchToWorkerThreadForToolExecution;
791+
private final boolean isCallerRunningOnWorkerThread;
790792

791793
public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications,
792794
Map<String, ToolExecutor> toolExecutors,
793-
List<Content> contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) {
795+
List<Content> contents, QuarkusAiServiceContext context, Object memoryId,
796+
boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) {
794797
// We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription.
795798
this.messagesToSend = messagesToSend;
796799
this.toolSpecifications = toolSpecifications;
797800
this.toolsExecutors = toolExecutors;
798801
this.contents = contents;
799802
this.context = context;
800803
this.memoryId = memoryId;
801-
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread;
804+
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
805+
this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread;
802806
}
803807

804808
@Override
@@ -811,19 +815,20 @@ public void subscribe(MultiSubscriber<? super String> subscriber) {
811815

812816
private void createTokenStream(UnicastProcessor<String> processor) {
813817
Context ctxt = null;
814-
if (mustSwitchToWorkerThread) {
818+
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
815819
// we create or retrieve the current context, to use `executeBlocking` when required.
816820
ctxt = VertxContext.getOrCreateDuplicatedContext();
817821
}
818822

819823
var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications,
820-
toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread);
824+
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
825+
isCallerRunningOnWorkerThread);
821826
TokenStream tokenStream = stream
822827
.onNext(processor::onNext)
823828
.onComplete(message -> processor.onComplete())
824829
.onError(processor::onError);
825830
// This is equivalent to "run subscription on worker thread"
826-
if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) {
831+
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {
827832
ctxt.executeBlocking(new Callable<Void>() {
828833
@Override
829834
public Void call() {

0 commit comments

Comments
 (0)