Skip to content

Commit 6369d00

Browse files
authored
Merge pull request #1011 from cescoffier/output-guardrail-streaming
Implement support for output guardrail on streamed responses
2 parents 7279fe5 + 1747654 commit 6369d00

14 files changed

+1689
-40
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
7171
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
7272
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
73+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
7374
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
7475
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
7576
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
@@ -751,11 +752,16 @@ public void markUsedOutputGuardRailsUnremovable(List<AiServicesMethodBuildItem>
751752
for (String cn : list) {
752753
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn)));
753754
}
755+
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
756+
if (method.methodInfo.hasAnnotation(dotName)) {
757+
unremovableProducer.produce(
758+
UnremovableBeanBuildItem.beanTypes(method.methodInfo.annotation(dotName).value().asClass().name()));
759+
}
754760
}
755761
}
756762

757763
@BuildStep
758-
public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished,
764+
public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
759765
List<AiServicesMethodBuildItem> methods,
760766
BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> errors) {
761767

@@ -768,6 +774,33 @@ public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished
768774
new DeploymentException("Missing guardrail bean: " + cn)));
769775
}
770776
}
777+
778+
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
779+
if (method.methodInfo.hasAnnotation(dotName)) {
780+
// We have an accumulator
781+
// Check that the accumulator exists
782+
var bean = method.methodInfo.annotation(dotName).value().asClass().name();
783+
if (synthesisFinished.beanStream().withBeanType(bean).isEmpty()) {
784+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
785+
new DeploymentException("Missing accumulator bean: " + bean.toString())));
786+
}
787+
788+
// Check that the accumulator is used on a method returning a Multi
789+
DotName returnedType = method.methodInfo.returnType().name();
790+
if (!DotName.createSimple(Multi.class).equals(returnedType)) {
791+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
792+
new DeploymentException("OutputGuardrailAccumulator can only be used on method returning a " +
793+
"`Multi<X>`: found `%s` for method `%s.%s`".formatted(returnedType,
794+
method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
795+
}
796+
797+
// Check that the method has output guardrails
798+
if (method.outputGuardrails.isEmpty()) {
799+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
800+
new DeploymentException("OutputGuardrailAccumulator used without OutputGuardrails in method `%s.%s`"
801+
.formatted(method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
802+
}
803+
}
771804
}
772805
}
773806

@@ -1165,11 +1198,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
11651198
List<String> outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS);
11661199
List<String> inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS);
11671200

1201+
String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);
1202+
11681203
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
11691204
userMessageInfo, memoryIdParamPosition, requiresModeration,
11701205
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
11711206
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails,
1172-
outputGuardrails);
1207+
outputGuardrails, accumulatorClassName);
11731208
}
11741209

11751210
private void validateReturnType(MethodInfo method) {
@@ -1690,5 +1725,18 @@ public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annot
16901725
}
16911726
return guardrails;
16921727
}
1728+
1729+
public static String gatherAccumulator(MethodInfo methodInfo) {
1730+
DotName annotation = DotName.createSimple(OutputGuardrailAccumulator.class);
1731+
AnnotationInstance instance = methodInfo.annotation(annotation);
1732+
if (instance == null) {
1733+
// Check on class
1734+
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
1735+
}
1736+
if (instance != null) {
1737+
return instance.value().asClass().name().toString();
1738+
}
1739+
return null;
1740+
}
16931741
}
16941742
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Fail.fail;
5+
6+
import java.util.List;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.enterprise.inject.spi.DeploymentException;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.ChatMemory;
21+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
22+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
23+
import dev.langchain4j.model.StreamingResponseHandler;
24+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
25+
import dev.langchain4j.model.output.Response;
26+
import dev.langchain4j.service.MemoryId;
27+
import dev.langchain4j.service.UserMessage;
28+
import io.quarkiverse.langchain4j.RegisterAiService;
29+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
30+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
31+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
32+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
33+
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
34+
import io.quarkus.test.QuarkusUnitTest;
35+
import io.smallrye.mutiny.Multi;
36+
37+
public class InvalidOutputGuardrailAccumulatorTest {
38+
39+
@RegisterExtension
40+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
41+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
42+
.addClasses(MyAiService.class,
43+
MyMemoryProviderSupplier.class))
44+
.assertException(t -> {
45+
assertThat(t).isInstanceOf(DeploymentException.class);
46+
assertThat(t).hasMessageContaining(
47+
"io.quarkiverse.langchain4j.test.guardrails.InvalidOutputGuardrailAccumulatorTest$MyAiService.hi");
48+
});
49+
50+
@Test
51+
@ActivateRequestContext
52+
void testThatInvalidAccumulatorAreReported() {
53+
fail("Should not be called");
54+
}
55+
56+
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
57+
public interface MyAiService {
58+
59+
@UserMessage("Say Hi!")
60+
@OutputGuardrails(MyGuardRail.class)
61+
@OutputGuardrailAccumulator(MyAccumulator.class)
62+
String hi(@MemoryId String mem);
63+
64+
}
65+
66+
@ApplicationScoped
67+
public static class MyAccumulator implements OutputTokenAccumulator {
68+
69+
@Override
70+
public Multi<String> accumulate(Multi<String> tokens) {
71+
return tokens;
72+
}
73+
}
74+
75+
@ApplicationScoped
76+
public static class MyGuardRail implements OutputGuardrail {
77+
78+
@Override
79+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
80+
throw new RuntimeException("Should not be invoked");
81+
}
82+
83+
}
84+
85+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
86+
@Override
87+
public ChatMemoryProvider get() {
88+
return new ChatMemoryProvider() {
89+
@Override
90+
public ChatMemory get(Object memoryId) {
91+
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
92+
}
93+
};
94+
}
95+
}
96+
97+
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
98+
99+
@Override
100+
public StreamingChatLanguageModel get() {
101+
return new StreamingChatLanguageModel() {
102+
@Override
103+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
104+
handler.onNext("Stream");
105+
handler.onNext("ing");
106+
handler.onNext(" ");
107+
handler.onNext("world");
108+
handler.onNext("!");
109+
handler.onComplete(Response.from(AiMessage.from("")));
110+
}
111+
};
112+
}
113+
}
114+
115+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Fail.fail;
5+
6+
import java.util.List;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.enterprise.inject.spi.DeploymentException;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.ChatMemory;
21+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
22+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
23+
import dev.langchain4j.model.StreamingResponseHandler;
24+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
25+
import dev.langchain4j.model.output.Response;
26+
import dev.langchain4j.service.MemoryId;
27+
import dev.langchain4j.service.UserMessage;
28+
import io.quarkiverse.langchain4j.RegisterAiService;
29+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
30+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
31+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
32+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
33+
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
34+
import io.quarkus.test.QuarkusUnitTest;
35+
import io.smallrye.mutiny.Multi;
36+
37+
public class OutputGuardrailAccumulatorNotFoundTest {
38+
39+
@RegisterExtension
40+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
41+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
42+
.addClasses(MyAiService.class,
43+
MyMemoryProviderSupplier.class))
44+
.assertException(t -> {
45+
assertThat(t).isInstanceOf(DeploymentException.class);
46+
assertThat(t).hasMessageContaining(
47+
"io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailAccumulatorNotFoundTest$MissingAccumulator");
48+
});
49+
50+
@Test
51+
@ActivateRequestContext
52+
void testThatNotFoundAccumulatorAreReported() {
53+
fail("Should not be called");
54+
}
55+
56+
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
57+
public interface MyAiService {
58+
59+
@UserMessage("Say Hi!")
60+
@OutputGuardrails(MyGuardRail.class)
61+
@OutputGuardrailAccumulator(MissingAccumulator.class)
62+
Multi<String> hi(@MemoryId String mem);
63+
64+
}
65+
66+
// Not a bean
67+
public static class MissingAccumulator implements OutputTokenAccumulator {
68+
69+
@Override
70+
public Multi<String> accumulate(Multi<String> tokens) {
71+
return tokens;
72+
}
73+
}
74+
75+
@ApplicationScoped
76+
public static class MyGuardRail implements OutputGuardrail {
77+
78+
@Override
79+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
80+
throw new RuntimeException("Should not be invoked");
81+
}
82+
83+
}
84+
85+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
86+
@Override
87+
public ChatMemoryProvider get() {
88+
return new ChatMemoryProvider() {
89+
@Override
90+
public ChatMemory get(Object memoryId) {
91+
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
92+
}
93+
};
94+
}
95+
}
96+
97+
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
98+
99+
@Override
100+
public StreamingChatLanguageModel get() {
101+
return new StreamingChatLanguageModel() {
102+
@Override
103+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
104+
handler.onNext("Stream");
105+
handler.onNext("ing");
106+
handler.onNext(" ");
107+
handler.onNext("world");
108+
handler.onNext("!");
109+
handler.onComplete(Response.from(AiMessage.from("")));
110+
}
111+
};
112+
}
113+
}
114+
115+
}

0 commit comments

Comments
 (0)