Skip to content

Commit ed12693

Browse files
authored
Merge pull request #1021 from mariofusco/out_guard_with_result
2 parents 5faca9b + 4ac0106 commit ed12693

File tree

10 files changed

+184
-13
lines changed

10 files changed

+184
-13
lines changed

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.test.guardrails;
22

33
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
45

56
import java.util.List;
67
import java.util.concurrent.atomic.AtomicInteger;
@@ -28,6 +29,7 @@
2829
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
2930
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
3031
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
32+
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
3133
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
3234
import io.quarkus.test.QuarkusUnitTest;
3335

@@ -78,6 +80,20 @@ void testThatRetryRestartTheChain() {
7880
assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
7981
}
8082

83+
@Test
84+
@ActivateRequestContext
85+
void testThatRewritesTheOutputTwiceInTheChain() {
86+
assertThat(aiService.rewritingSuccess("1", "foo")).isEqualTo("Hi!,1,2");
87+
}
88+
89+
@Test
90+
@ActivateRequestContext
91+
void testThatRepromptAfterRewriteIsNotAllowed() {
92+
assertThatExceptionOfType(GuardrailException.class)
93+
.isThrownBy(() -> aiService.repromptAfterRewrite("1", "foo"))
94+
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
95+
}
96+
8197
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
8298
public interface MyAiService {
8399

@@ -90,6 +106,12 @@ public interface MyAiService {
90106
@OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
91107
String failingFirstTwo(@MemoryId String mem, @UserMessage String message);
92108

109+
@OutputGuardrails({ FirstRewritingGuardrail.class, SecondRewritingGuardrail.class })
110+
String rewritingSuccess(@MemoryId String mem, @UserMessage String message);
111+
112+
@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
113+
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);
114+
93115
}
94116

95117
@RequestScoped
@@ -164,6 +186,42 @@ public int spy() {
164186
}
165187
}
166188

189+
@RequestScoped
190+
public static class FirstRewritingGuardrail implements OutputGuardrail {
191+
192+
@Override
193+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
194+
String text = responseFromLLM.text();
195+
return successWith(text + ",1");
196+
}
197+
}
198+
199+
@RequestScoped
200+
public static class SecondRewritingGuardrail implements OutputGuardrail {
201+
202+
@Override
203+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
204+
String text = responseFromLLM.text();
205+
return successWith(text + ",2");
206+
}
207+
}
208+
209+
@RequestScoped
210+
public static class RepromptingGuardrail implements OutputGuardrail {
211+
212+
private boolean firstCall = true;
213+
214+
@Override
215+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
216+
if (firstCall) {
217+
firstCall = false;
218+
String text = responseFromLLM.text();
219+
return reprompt("Wrong message", text + ", " + text);
220+
}
221+
return success();
222+
}
223+
}
224+
167225
public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {
168226

169227
@Override

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,14 @@ void testFatalExceptionWithPassThroughAccumulator() {
139139
assertThat(fatal.spy()).isEqualTo(1);
140140
}
141141

142+
@Test
143+
@ActivateRequestContext
144+
void testRewritingWhileStreamingIsNotAllowed() {
145+
assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely())
146+
.isInstanceOf(GuardrailException.class)
147+
.hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed");
148+
}
149+
142150
@RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
143151
public interface MyAiService {
144152

@@ -187,6 +195,9 @@ public interface MyAiService {
187195
@OutputGuardrailAccumulator(PassThroughAccumulator.class)
188196
Multi<String> fatalWithPassThroughAccumulator(@MemoryId String mem);
189197

198+
@UserMessage("Say Hi!")
199+
@OutputGuardrails({ RewritingGuardrail.class })
200+
Multi<String> rewriting(@MemoryId String mem);
190201
}
191202

192203
@RequestScoped
@@ -272,6 +283,16 @@ public int spy() {
272283
}
273284
}
274285

286+
@RequestScoped
287+
public static class RewritingGuardrail implements OutputGuardrail {
288+
289+
@Override
290+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
291+
String text = responseFromLLM.text();
292+
return successWith(text + ",1");
293+
}
294+
}
295+
275296
public static class MyChatModelSupplier implements Supplier<StreamingChatLanguageModel> {
276297

277298
@Override

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,12 @@ public interface GuardrailParams {
1818
* @return the augmentation result, can be {@code null}
1919
*/
2020
AugmentationResult augmentationResult();
21+
22+
/**
23+
* Recreate this guardrail param with the given input or output text.
24+
*
25+
* @param text The text of the rewritten param.
26+
* @return A clone of this guardrail params with the given input or output text.
27+
*/
28+
GuardrailParams withText(String text);
2129
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ enum Result {
1515
* A successful validation.
1616
*/
1717
SUCCESS,
18+
/**
19+
* A successful validation with a specific result.
20+
*/
21+
SUCCESS_WITH_RESULT,
1822
/**
1923
* A failed validation not preventing the subsequent validations eventually registered to be evaluated.
2024
*/
@@ -27,6 +31,18 @@ enum Result {
2731

2832
boolean isSuccess();
2933

34+
default boolean isRewrittenResult() {
35+
return false;
36+
}
37+
38+
default GuardrailResult<GR> blockRetry() {
39+
throw new UnsupportedOperationException();
40+
}
41+
42+
default String successfulResult() {
43+
throw new UnsupportedOperationException();
44+
}
45+
3046
boolean isFatal();
3147

3248
/**

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,9 @@
1818
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
1919
AugmentationResult augmentationResult, String userMessageTemplate,
2020
Map<String, Object> variables) implements GuardrailParams {
21+
22+
@Override
23+
public InputGuardrailParams withText(String text) {
24+
throw new UnsupportedOperationException();
25+
}
2126
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ default OutputGuardrailResult success() {
5050
return OutputGuardrailResult.success();
5151
}
5252

53+
/**
54+
* @return The result of a successful output guardrail validation with a specific result.
55+
* @param successfulResult The successful result.
56+
*/
57+
default OutputGuardrailResult successWith(String successfulResult) {
58+
return OutputGuardrailResult.successWith(successfulResult);
59+
}
60+
5361
/**
5462
* @param message A message describing the failure.
5563
* @return The result of a failed output guardrail validation.

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

3+
import java.util.List;
34
import java.util.Map;
45

6+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
57
import dev.langchain4j.data.message.AiMessage;
68
import dev.langchain4j.memory.ChatMemory;
79
import dev.langchain4j.rag.AugmentationResult;
@@ -18,4 +20,11 @@
1820
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
1921
AugmentationResult augmentationResult, String userMessageTemplate,
2022
Map<String, Object> variables) implements GuardrailParams {
23+
24+
@Override
25+
public OutputGuardrailParams withText(String text) {
26+
List<ToolExecutionRequest> tools = responseFromLLM.toolExecutionRequests();
27+
AiMessage aiMessage = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
28+
return new OutputGuardrailParams(aiMessage, memory, augmentationResult, userMessageTemplate, variables);
29+
}
2130
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,54 @@
1010
* @param result The result of the output guardrail validation.
1111
* @param failures The list of failures, empty if the validation succeeded.
1212
*/
13-
public record OutputGuardrailResult(Result result, List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {
13+
public record OutputGuardrailResult(Result result, String successfulResult,
14+
List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {
1415

1516
private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();
1617

1718
private OutputGuardrailResult() {
18-
this(Result.SUCCESS, Collections.emptyList());
19+
this(Result.SUCCESS, null, Collections.emptyList());
20+
}
21+
22+
private OutputGuardrailResult(String successfulResult) {
23+
this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
1924
}
2025

2126
OutputGuardrailResult(List<Failure> failures, boolean fatal) {
22-
this(fatal ? Result.FATAL : Result.FAILURE, failures);
27+
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
2328
}
2429

2530
public static OutputGuardrailResult success() {
2631
return SUCCESS;
2732
}
2833

34+
public static OutputGuardrailResult successWith(String successfulResult) {
35+
return new OutputGuardrailResult(successfulResult);
36+
}
37+
2938
public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
3039
return new OutputGuardrailResult((List<Failure>) failures, false);
3140
}
3241

3342
@Override
3443
public boolean isSuccess() {
35-
return result == Result.SUCCESS;
44+
return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT;
45+
}
46+
47+
@Override
48+
public boolean isRewrittenResult() {
49+
return result == Result.SUCCESS_WITH_RESULT;
3650
}
3751

3852
public boolean isRetry() {
3953
return !isSuccess() && failures.stream().anyMatch(Failure::retry);
4054
}
4155

56+
public OutputGuardrailResult blockRetry() {
57+
failures().set(0, failures().get(0).blockRetry());
58+
return this;
59+
}
60+
4261
public String getReprompt() {
4362
if (!isSuccess()) {
4463
for (Failure failure : failures) {
@@ -97,6 +116,13 @@ public Failure withGuardrailClass(Class<? extends Guardrail> guardrailClass) {
97116
return new Failure(message(), cause(), guardrailClass, retry, reprompt);
98117
}
99118

119+
public Failure blockRetry() {
120+
return retry
121+
? new Failure("Retry or reprompt is not allowed after a rewritten output", cause(), guardrailClass, false,
122+
reprompt)
123+
: this;
124+
}
125+
100126
@Override
101127
public String toString() {
102128
return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message;

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
290290
throw new GuardrailsSupport.GuardrailRetryException();
291291
}
292292
} else {
293+
if (result.isRewrittenResult()) {
294+
throw new GuardrailException(
295+
"Attempting to rewrite the LLM output while streaming is not allowed");
296+
}
293297
return chunk;
294298
}
295299
})

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import jakarta.enterprise.inject.spi.CDI;
99

10+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
1011
import dev.langchain4j.agent.tool.ToolSpecification;
1112
import dev.langchain4j.data.message.AiMessage;
1213
import dev.langchain4j.data.message.UserMessage;
@@ -57,8 +58,9 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
5758
if (max <= 0) {
5859
max = 1;
5960
}
61+
62+
OutputGuardrailResult result = null;
6063
while (attempt < max) {
61-
OutputGuardrailResult result;
6264
try {
6365
result = invokeOutputGuardRails(methodCreateInfo, output);
6466
} catch (Exception e) {
@@ -97,9 +99,20 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
9799
if (attempt == max) {
98100
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
99101
}
102+
103+
if (result.isRewrittenResult()) {
104+
response = rewriteResponseWithText(response, result.successfulResult());
105+
}
106+
100107
return response;
101108
}
102109

110+
public static Response<AiMessage> rewriteResponseWithText(Response<AiMessage> response, String text) {
111+
List<ToolExecutionRequest> tools = response.content().toolExecutionRequests();
112+
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
113+
return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
114+
}
115+
103116
@SuppressWarnings("unchecked")
104117
private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo,
105118
OutputGuardrailParams params) {
@@ -160,25 +173,28 @@ private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams p
160173
for (Class<? extends Guardrail> bean : classes) {
161174
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
162175
if (result.isFatal()) {
163-
return result;
176+
return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
177+
}
178+
if (result.isRewrittenResult()) {
179+
params = params.withText(result.successfulResult());
164180
}
165181
accumulatedResults = compose(accumulatedResults, result, producer);
166182
}
167183

168184
return accumulatedResults;
169185
}
170186

171-
private static <GR extends GuardrailResult> GR compose(GR first, GR second,
187+
private static <GR extends GuardrailResult> GR compose(GR oldResult, GR newResult,
172188
Function<List<? extends GuardrailResult.Failure>, GR> producer) {
173-
if (first.isSuccess()) {
174-
return second;
189+
if (oldResult.isSuccess()) {
190+
return newResult;
175191
}
176-
if (second.isSuccess()) {
177-
return first;
192+
if (newResult.isSuccess()) {
193+
return oldResult;
178194
}
179195
List<? extends GuardrailResult.Failure> failures = new ArrayList<>();
180-
failures.addAll(first.failures());
181-
failures.addAll(second.failures());
196+
failures.addAll(oldResult.failures());
197+
failures.addAll(newResult.failures());
182198
return producer.apply(failures);
183199
}
184200

0 commit comments

Comments
 (0)