Skip to content

Commit 1747654

Browse files
committed
Implement support for output guardrail on streamed responses
The user can implement a custom accumulator to decide when to invoke the guardrail chain.
1 parent fb8bf24 commit 1747654

14 files changed

+1689
-40
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
7171
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
7272
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
73+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
7374
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
7475
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
7576
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
@@ -746,11 +747,16 @@ public void markUsedOutputGuardRailsUnremovable(List<AiServicesMethodBuildItem>
746747
for (String cn : list) {
747748
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn)));
748749
}
750+
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
751+
if (method.methodInfo.hasAnnotation(dotName)) {
752+
unremovableProducer.produce(
753+
UnremovableBeanBuildItem.beanTypes(method.methodInfo.annotation(dotName).value().asClass().name()));
754+
}
749755
}
750756
}
751757

752758
@BuildStep
753-
public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished,
759+
public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
754760
List<AiServicesMethodBuildItem> methods,
755761
BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> errors) {
756762

@@ -763,6 +769,33 @@ public void detectMissingGuardRails(SynthesisFinishedBuildItem synthesisFinished
763769
new DeploymentException("Missing guardrail bean: " + cn)));
764770
}
765771
}
772+
773+
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
774+
if (method.methodInfo.hasAnnotation(dotName)) {
775+
// We have an accumulator
776+
// Check that the accumulator exists
777+
var bean = method.methodInfo.annotation(dotName).value().asClass().name();
778+
if (synthesisFinished.beanStream().withBeanType(bean).isEmpty()) {
779+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
780+
new DeploymentException("Missing accumulator bean: " + bean.toString())));
781+
}
782+
783+
// Check that the accumulator is used on a method returning a Multi
784+
DotName returnedType = method.methodInfo.returnType().name();
785+
if (!DotName.createSimple(Multi.class).equals(returnedType)) {
786+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
787+
new DeploymentException("OutputGuardrailAccumulator can only be used on method returning a " +
788+
"`Multi<X>`: found `%s` for method `%s.%s`".formatted(returnedType,
789+
method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
790+
}
791+
792+
// Check that the method has output guardrails
793+
if (method.outputGuardrails.isEmpty()) {
794+
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
795+
new DeploymentException("OutputGuardrailAccumulator used without OutputGuardrails in method `%s.%s`"
796+
.formatted(method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
797+
}
798+
}
766799
}
767800
}
768801

@@ -1160,11 +1193,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
11601193
List<String> outputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, OUTPUT_GUARDRAILS);
11611194
List<String> inputGuardrails = AiServicesMethodBuildItem.gatherGuardrails(method, INPUT_GUARDRAILS);
11621195

1196+
String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);
1197+
11631198
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
11641199
userMessageInfo, memoryIdParamPosition, requiresModeration,
11651200
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
11661201
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails,
1167-
outputGuardrails);
1202+
outputGuardrails, accumulatorClassName);
11681203
}
11691204

11701205
private void validateReturnType(MethodInfo method) {
@@ -1685,5 +1720,18 @@ public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annot
16851720
}
16861721
return guardrails;
16871722
}
1723+
1724+
public static String gatherAccumulator(MethodInfo methodInfo) {
1725+
DotName annotation = DotName.createSimple(OutputGuardrailAccumulator.class);
1726+
AnnotationInstance instance = methodInfo.annotation(annotation);
1727+
if (instance == null) {
1728+
// Check on class
1729+
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
1730+
}
1731+
if (instance != null) {
1732+
return instance.value().asClass().name().toString();
1733+
}
1734+
return null;
1735+
}
16881736
}
16891737
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Fail.fail;
5+
6+
import java.util.List;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.enterprise.inject.spi.DeploymentException;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.ChatMemory;
21+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
22+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
23+
import dev.langchain4j.model.StreamingResponseHandler;
24+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
25+
import dev.langchain4j.model.output.Response;
26+
import dev.langchain4j.service.MemoryId;
27+
import dev.langchain4j.service.UserMessage;
28+
import io.quarkiverse.langchain4j.RegisterAiService;
29+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
30+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
31+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
32+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
33+
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
34+
import io.quarkus.test.QuarkusUnitTest;
35+
import io.smallrye.mutiny.Multi;
36+
37+
public class InvalidOutputGuardrailAccumulatorTest {
38+
39+
@RegisterExtension
40+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
41+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
42+
.addClasses(MyAiService.class,
43+
MyMemoryProviderSupplier.class))
44+
.assertException(t -> {
45+
assertThat(t).isInstanceOf(DeploymentException.class);
46+
assertThat(t).hasMessageContaining(
47+
"io.quarkiverse.langchain4j.test.guardrails.InvalidOutputGuardrailAccumulatorTest$MyAiService.hi");
48+
});
49+
50+
@Test
51+
@ActivateRequestContext
52+
void testThatInvalidAccumulatorAreReported() {
53+
fail("Should not be called");
54+
}
55+
56+
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
57+
public interface MyAiService {
58+
59+
@UserMessage("Say Hi!")
60+
@OutputGuardrails(MyGuardRail.class)
61+
@OutputGuardrailAccumulator(MyAccumulator.class)
62+
String hi(@MemoryId String mem);
63+
64+
}
65+
66+
@ApplicationScoped
67+
public static class MyAccumulator implements OutputTokenAccumulator {
68+
69+
@Override
70+
public Multi<String> accumulate(Multi<String> tokens) {
71+
return tokens;
72+
}
73+
}
74+
75+
@ApplicationScoped
76+
public static class MyGuardRail implements OutputGuardrail {
77+
78+
@Override
79+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
80+
throw new RuntimeException("Should not be invoked");
81+
}
82+
83+
}
84+
85+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
86+
@Override
87+
public ChatMemoryProvider get() {
88+
return new ChatMemoryProvider() {
89+
@Override
90+
public ChatMemory get(Object memoryId) {
91+
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
92+
}
93+
};
94+
}
95+
}
96+
97+
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
98+
99+
@Override
100+
public StreamingChatLanguageModel get() {
101+
return new StreamingChatLanguageModel() {
102+
@Override
103+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
104+
handler.onNext("Stream");
105+
handler.onNext("ing");
106+
handler.onNext(" ");
107+
handler.onNext("world");
108+
handler.onNext("!");
109+
handler.onComplete(Response.from(AiMessage.from("")));
110+
}
111+
};
112+
}
113+
}
114+
115+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Fail.fail;
5+
6+
import java.util.List;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.enterprise.inject.spi.DeploymentException;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.ChatMemory;
21+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
22+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
23+
import dev.langchain4j.model.StreamingResponseHandler;
24+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
25+
import dev.langchain4j.model.output.Response;
26+
import dev.langchain4j.service.MemoryId;
27+
import dev.langchain4j.service.UserMessage;
28+
import io.quarkiverse.langchain4j.RegisterAiService;
29+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
30+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
31+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
32+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
33+
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
34+
import io.quarkus.test.QuarkusUnitTest;
35+
import io.smallrye.mutiny.Multi;
36+
37+
public class OutputGuardrailAccumulatorNotFoundTest {
38+
39+
@RegisterExtension
40+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
41+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
42+
.addClasses(MyAiService.class,
43+
MyMemoryProviderSupplier.class))
44+
.assertException(t -> {
45+
assertThat(t).isInstanceOf(DeploymentException.class);
46+
assertThat(t).hasMessageContaining(
47+
"io.quarkiverse.langchain4j.test.guardrails.OutputGuardrailAccumulatorNotFoundTest$MissingAccumulator");
48+
});
49+
50+
@Test
51+
@ActivateRequestContext
52+
void testThatNotFoundAccumulatorAreReported() {
53+
fail("Should not be called");
54+
}
55+
56+
@RegisterAiService(streamingChatLanguageModelSupplier = MyStreamingChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
57+
public interface MyAiService {
58+
59+
@UserMessage("Say Hi!")
60+
@OutputGuardrails(MyGuardRail.class)
61+
@OutputGuardrailAccumulator(MissingAccumulator.class)
62+
Multi<String> hi(@MemoryId String mem);
63+
64+
}
65+
66+
// Not a bean
67+
public static class MissingAccumulator implements OutputTokenAccumulator {
68+
69+
@Override
70+
public Multi<String> accumulate(Multi<String> tokens) {
71+
return tokens;
72+
}
73+
}
74+
75+
@ApplicationScoped
76+
public static class MyGuardRail implements OutputGuardrail {
77+
78+
@Override
79+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
80+
throw new RuntimeException("Should not be invoked");
81+
}
82+
83+
}
84+
85+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
86+
@Override
87+
public ChatMemoryProvider get() {
88+
return new ChatMemoryProvider() {
89+
@Override
90+
public ChatMemory get(Object memoryId) {
91+
return new MessageWindowChatMemory.Builder().maxMessages(5).build();
92+
}
93+
};
94+
}
95+
}
96+
97+
public static class MyStreamingChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
98+
99+
@Override
100+
public StreamingChatLanguageModel get() {
101+
return new StreamingChatLanguageModel() {
102+
@Override
103+
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
104+
handler.onNext("Stream");
105+
handler.onNext("ing");
106+
handler.onNext(" ");
107+
handler.onNext("world");
108+
handler.onNext("!");
109+
handler.onComplete(Response.from(AiMessage.from("")));
110+
}
111+
};
112+
}
113+
}
114+
115+
}

0 commit comments

Comments
 (0)