Skip to content

Commit 742b816

Browse files
authored
Merge pull request #1144 from cescoffier/fix-blocking-memory-store-in-streamed-response
Fix Blocking Memory Store Usage in Streamed Mode
2 parents 1ca2656 + 0690a53 commit 742b816

File tree

7 files changed

+379
-48
lines changed

7 files changed

+379
-48
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,13 +1295,15 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
12951295
String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method);
12961296

12971297
// Detect if tools execution may block the caller thread.
1298-
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);
1298+
boolean switchToWorkerThreadForToolExecution = detectIfToolExecutionRequiresAWorkerThread(method, tools,
1299+
methodToolClassNames);
12991300

13001301
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
13011302
userMessageInfo, memoryIdParamPosition, requiresModeration,
13021303
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
1303-
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
1304-
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
1304+
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames,
1305+
switchToWorkerThreadForToolExecution, inputGuardrails, outputGuardrails, accumulatorClassName,
1306+
responseAugmenterClassName);
13051307
}
13061308

13071309
private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package io.quarkiverse.langchain4j.test.streaming;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.List;
6+
import java.util.UUID;
7+
import java.util.concurrent.CountDownLatch;
8+
import java.util.concurrent.TimeUnit;
9+
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.inject.Inject;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.BeforeEach;
16+
import org.junit.jupiter.api.RepeatedTest;
17+
import org.junit.jupiter.api.Test;
18+
import org.junit.jupiter.api.extension.RegisterExtension;
19+
20+
import dev.langchain4j.service.MemoryId;
21+
import dev.langchain4j.service.UserMessage;
22+
import io.quarkiverse.langchain4j.RegisterAiService;
23+
import io.quarkus.arc.Arc;
24+
import io.quarkus.test.QuarkusUnitTest;
25+
import io.smallrye.common.vertx.VertxContext;
26+
import io.smallrye.mutiny.Multi;
27+
import io.vertx.core.Context;
28+
import io.vertx.core.Vertx;
29+
30+
public class BlockingMemoryStoreOnStreamedResponseTest {
31+
32+
@RegisterExtension
33+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
34+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
35+
.addClasses(StreamTestUtils.class));
36+
37+
@Inject
38+
MyAiService service;
39+
40+
// Streams two replies for the same memory id ("123") and checks the collected chunks of each reply.
// Repeated 100 times so that an ordering problem has a chance to show up.
@RepeatedTest(100) // Verify that the order is preserved.
41+
@ActivateRequestContext
42+
void testFromWorkerThread() {
43+
// We are on a worker thread.
44+
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
45+
// containsExactly is order-sensitive: the chunks must be collected in the order the fake model emits them.
46+
assertThat(list).containsExactly("Hi!", " ", "World!");
47+
48+
// The fake model answers "OK!" to "Second message" only when it receives at least three messages.
list = service.hi("123", "Second message").collect().asList().await().indefinitely();
49+
assertThat(list).containsExactly("OK!");
50+
}
51+
52+
// Resets the static DC_DATA marker before every test; FakeMemoryStore only performs its
// context check when the marker is non-null.
@BeforeEach
53+
void cleanup() {
54+
StreamTestUtils.FakeMemoryStore.DC_DATA = null;
55+
}
56+
57+
@RepeatedTest(10)
58+
void testFromDuplicatedContextThread() throws InterruptedException {
59+
Context context = VertxContext.getOrCreateDuplicatedContext(vertx);
60+
CountDownLatch latch = new CountDownLatch(1);
61+
context.executeBlocking(v -> {
62+
try {
63+
Arc.container().requestContext().activate();
64+
var value = UUID.randomUUID().toString();
65+
StreamTestUtils.FakeMemoryStore.DC_DATA = value;
66+
Vertx.currentContext().putLocal("DC_DATA", value);
67+
List<String> list = service.hi("123", "Say hello").collect().asList().await().indefinitely();
68+
assertThat(list).containsExactly("Hi!", " ", "World!");
69+
Arc.container().requestContext().deactivate();
70+
71+
Arc.container().requestContext().activate();
72+
73+
list = service.hi("123", "Second message").collect().asList().await().indefinitely();
74+
assertThat(list).containsExactly("OK!");
75+
latch.countDown();
76+
77+
} finally {
78+
Arc.container().requestContext().deactivate();
79+
Vertx.currentContext().removeLocal("DC_DATA");
80+
81+
}
82+
}, false);
83+
assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue();
84+
}
85+
86+
@Inject
87+
Vertx vertx;
88+
89+
@Test
90+
void testFromEventLoopThread() throws InterruptedException {
91+
CountDownLatch latch = new CountDownLatch(1);
92+
Context context = vertx.getOrCreateContext();
93+
context.runOnContext(v -> {
94+
Arc.container().requestContext().activate();
95+
try {
96+
service.hi("123", "Say hello").collect().asList()
97+
.subscribe().asCompletionStage();
98+
} catch (Exception e) {
99+
assertThat(e)
100+
.isNotNull()
101+
.hasMessageContaining("Expected to be able to block");
102+
} finally {
103+
Arc.container().requestContext().deactivate();
104+
latch.countDown();
105+
}
106+
});
107+
latch.await();
108+
}
109+
110+
// AI service under test: registered with the fake streaming model supplier and the fake
// chat-memory provider supplier from StreamTestUtils.
@RegisterAiService(streamingChatLanguageModelSupplier = StreamTestUtils.FakeStreamedChatModelSupplier.class, chatMemoryProviderSupplier = StreamTestUtils.FakeMemoryProviderSupplier.class)
111+
public interface MyAiService {
112+
113+
// Returns the reply as a Multi of String chunks; `id` is annotated as the memory id, `query` as the user message.
Multi<String> hi(@MemoryId String id, @UserMessage String query);
114+
115+
}
116+
117+
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package io.quarkiverse.langchain4j.test.streaming;
2+
3+
import java.util.List;
4+
import java.util.Map;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import java.util.concurrent.CopyOnWriteArrayList;
7+
import java.util.function.Supplier;
8+
9+
import dev.langchain4j.data.message.AiMessage;
10+
import dev.langchain4j.data.message.ChatMessage;
11+
import dev.langchain4j.data.message.UserMessage;
12+
import dev.langchain4j.memory.ChatMemory;
13+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
14+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
15+
import dev.langchain4j.model.StreamingResponseHandler;
16+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
17+
import dev.langchain4j.model.output.Response;
18+
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
19+
import io.quarkus.arc.Arc;
20+
import io.smallrye.mutiny.infrastructure.Infrastructure;
21+
import io.vertx.core.Vertx;
22+
23+
/**
24+
* Utility class for streaming tests.
25+
*/
26+
public class StreamTestUtils {
27+
28+
public static class FakeMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
29+
@Override
30+
public ChatMemoryProvider get() {
31+
return new ChatMemoryProvider() {
32+
@Override
33+
public ChatMemory get(Object memoryId) {
34+
return new MessageWindowChatMemory.Builder()
35+
.id(memoryId)
36+
.maxMessages(10)
37+
.chatMemoryStore(new FakeMemoryStore())
38+
.build();
39+
}
40+
};
41+
}
42+
}
43+
44+
public static class FakeStreamedChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
45+
46+
@Override
47+
public StreamingChatLanguageModel get() {
48+
return new FakeStreamedChatModel();
49+
}
50+
}
51+
52+
public static class FakeStreamedChatModel implements StreamingChatLanguageModel {
53+
54+
@Override
55+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
56+
Vertx vertx = Arc.container().select(Vertx.class).get();
57+
var ctxt = vertx.getOrCreateContext();
58+
59+
if (messages.size() > 1) {
60+
var last = (UserMessage) messages.get(messages.size() - 1);
61+
if (last.singleText().equalsIgnoreCase("Second message")) {
62+
if (messages.size() < 3) {
63+
ctxt.runOnContext(x -> handler.onError(new IllegalStateException("Error - no memory")));
64+
return;
65+
} else {
66+
ctxt.runOnContext(x -> {
67+
handler.onNext("OK!");
68+
handler.onComplete(Response.from(AiMessage.from("")));
69+
});
70+
return;
71+
}
72+
}
73+
}
74+
75+
ctxt.runOnContext(x1 -> {
76+
handler.onNext("Hi!");
77+
ctxt.runOnContext(x2 -> {
78+
handler.onNext(" ");
79+
ctxt.runOnContext(x3 -> {
80+
handler.onNext("World!");
81+
ctxt.runOnContext(x -> handler.onComplete(Response.from(AiMessage.from(""))));
82+
});
83+
});
84+
});
85+
}
86+
}
87+
88+
public static class FakeMemoryStore implements ChatMemoryStore {
89+
90+
public static String DC_DATA;
91+
92+
private static final Map<Object, List<ChatMessage>> memories = new ConcurrentHashMap<>();
93+
94+
private void checkDuplicatedContext() {
95+
if (DC_DATA != null) {
96+
if (!DC_DATA.equals(Vertx.currentContext().getLocal("DC_DATA"))) {
97+
throw new AssertionError("Expected to be in the same context");
98+
}
99+
}
100+
}
101+
102+
@Override
103+
public List<ChatMessage> getMessages(Object memoryId) {
104+
if (!Infrastructure.canCallerThreadBeBlocked()) {
105+
throw new AssertionError("Expected to be able to block");
106+
}
107+
checkDuplicatedContext();
108+
return memories.computeIfAbsent(memoryId, x -> new CopyOnWriteArrayList<>());
109+
}
110+
111+
@Override
112+
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
113+
if (!Infrastructure.canCallerThreadBeBlocked()) {
114+
throw new AssertionError("Expected to be able to block");
115+
}
116+
memories.put(memoryId, messages);
117+
}
118+
119+
@Override
120+
public void deleteMessages(Object memoryId) {
121+
if (!Infrastructure.canCallerThreadBeBlocked()) {
122+
throw new AssertionError("Expected to be able to block");
123+
}
124+
}
125+
}
126+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public final class AiServiceMethodCreateInfo {
6060
private OutputTokenAccumulator accumulator;
6161

6262
private final LazyValue<Integer> guardrailsMaxRetry;
63-
private final boolean switchToWorkerThread;
63+
private final boolean switchToWorkerThreadForToolExecution;
6464

6565
@RecordableConstructor
6666
public AiServiceMethodCreateInfo(String interfaceName, String methodName,
@@ -74,7 +74,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName,
7474
Optional<SpanInfo> spanInfo,
7575
ResponseSchemaInfo responseSchemaInfo,
7676
List<String> toolClassNames,
77-
boolean switchToWorkerThread,
77+
boolean switchToWorkerThreadForToolExecution,
7878
List<String> inputGuardrailsClassNames,
7979
List<String> outputGuardrailsClassNames,
8080
String outputTokenAccumulatorClassName,
@@ -108,7 +108,7 @@ public Integer get() {
108108
.orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT);
109109
}
110110
});
111-
this.switchToWorkerThread = switchToWorkerThread;
111+
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
112112
this.responseAugmenterClassName = responseAugmenterClassName;
113113
}
114114

@@ -238,8 +238,8 @@ public String getUserMessageTemplate() {
238238
return userMessageTemplateOpt.orElse("");
239239
}
240240

241-
public boolean isSwitchToWorkerThread() {
242-
return switchToWorkerThread;
241+
public boolean isSwitchToWorkerThreadForToolExecution() {
242+
return switchToWorkerThreadForToolExecution;
243243
}
244244

245245
public void setResponseAugmenter(Class<? extends AiResponseAugmenter<?>> augmenter) {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ public Object implement(Input input) {
152152

153153
private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs,
154154
QuarkusAiServiceContext context, Audit audit) {
155+
boolean isRunningOnWorkerThread = !Context.isOnEventLoopThread();
155156
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
156157
Optional<SystemMessage> systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs,
157158
context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList());
@@ -227,7 +228,7 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
227228
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
228229
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
229230
finalToolExecutors, ar.contents(), context, memoryId,
230-
methodCreateInfo.isSwitchToWorkerThread());
231+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
231232
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
232233
new ResponseAugmenterParams((UserMessage) augmentedUserMessage,
233234
memory, ar, methodCreateInfo.getUserMessageTemplate(),
@@ -278,7 +279,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
278279
if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
279280
var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
280281
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
281-
methodCreateInfo.isSwitchToWorkerThread());
282+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread);
282283
return stream.plug(m -> ResponseAugmenterSupport.apply(m, methodCreateInfo,
283284
new ResponseAugmenterParams(actualUserMessage,
284285
chatMemory, actualAugmentationResult, methodCreateInfo.getUserMessageTemplate(),
@@ -287,7 +288,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
287288

288289
return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors,
289290
(augmentationResult != null ? augmentationResult.contents() : null), context, memoryId,
290-
methodCreateInfo.isSwitchToWorkerThread())
291+
methodCreateInfo.isSwitchToWorkerThreadForToolExecution(), isRunningOnWorkerThread)
291292
.plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo))
292293
.map(chunk -> {
293294
OutputGuardrailResult result;
@@ -785,19 +786,22 @@ private static class TokenStreamMulti extends AbstractMulti<String> implements M
785786
private final List<Content> contents;
786787
private final QuarkusAiServiceContext context;
787788
private final Object memoryId;
788-
private final boolean mustSwitchToWorkerThread;
789+
private final boolean switchToWorkerThreadForToolExecution;
790+
private final boolean isCallerRunningOnWorkerThread;
789791

790792
public TokenStreamMulti(List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications,
791793
Map<String, ToolExecutor> toolExecutors,
792-
List<Content> contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) {
794+
List<Content> contents, QuarkusAiServiceContext context, Object memoryId,
795+
boolean switchToWorkerThreadForToolExecution, boolean isCallerRunningOnWorkerThread) {
793796
// We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription.
794797
this.messagesToSend = messagesToSend;
795798
this.toolSpecifications = toolSpecifications;
796799
this.toolsExecutors = toolExecutors;
797800
this.contents = contents;
798801
this.context = context;
799802
this.memoryId = memoryId;
800-
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread;
803+
this.switchToWorkerThreadForToolExecution = switchToWorkerThreadForToolExecution;
804+
this.isCallerRunningOnWorkerThread = isCallerRunningOnWorkerThread;
801805
}
802806

803807
@Override
@@ -810,19 +814,20 @@ public void subscribe(MultiSubscriber<? super String> subscriber) {
810814

811815
private void createTokenStream(UnicastProcessor<String> processor) {
812816
Context ctxt = null;
813-
if (mustSwitchToWorkerThread) {
817+
if (switchToWorkerThreadForToolExecution || isCallerRunningOnWorkerThread) {
814818
// we create or retrieve the current context, to use `executeBlocking` when required.
815819
ctxt = VertxContext.getOrCreateDuplicatedContext();
816820
}
817821

818822
var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications,
819-
toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread);
823+
toolsExecutors, contents, context, memoryId, ctxt, switchToWorkerThreadForToolExecution,
824+
isCallerRunningOnWorkerThread);
820825
TokenStream tokenStream = stream
821826
.onNext(processor::onNext)
822827
.onComplete(message -> processor.onComplete())
823828
.onError(processor::onError);
824829
// This is equivalent to "run subscription on worker thread"
825-
if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) {
830+
if (switchToWorkerThreadForToolExecution && Context.isOnEventLoopThread()) {
826831
ctxt.executeBlocking(new Callable<Void>() {
827832
@Override
828833
public Void call() {

0 commit comments

Comments
 (0)