Skip to content

Commit 6bb2352

Browse files
committed
Allow rewriting of user messages from input guardrails
1 parent 8a29c07 commit 6bb2352

File tree

8 files changed

+157
-32
lines changed

8 files changed

+157
-32
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import java.util.List;
6+
import java.util.function.Supplier;
7+
8+
import jakarta.enterprise.context.RequestScoped;
9+
import jakarta.enterprise.context.control.ActivateRequestContext;
10+
import jakarta.inject.Inject;
11+
12+
import org.jboss.shrinkwrap.api.ShrinkWrap;
13+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
14+
import org.junit.jupiter.api.Test;
15+
import org.junit.jupiter.api.extension.RegisterExtension;
16+
17+
import dev.langchain4j.data.message.AiMessage;
18+
import dev.langchain4j.data.message.ChatMessage;
19+
import dev.langchain4j.memory.ChatMemory;
20+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
21+
import dev.langchain4j.model.chat.ChatLanguageModel;
22+
import dev.langchain4j.model.output.Response;
23+
import dev.langchain4j.service.UserMessage;
24+
import io.quarkiverse.langchain4j.RegisterAiService;
25+
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
26+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
27+
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
28+
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
29+
import io.quarkus.test.QuarkusUnitTest;
30+
31+
public class InputGuardrailRewritingTest {
32+
33+
@RegisterExtension
34+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
35+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
36+
.addClasses(MyAiService.class, MessageTruncatingGuardrail.class, EchoChatModel.class,
37+
MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
38+
39+
@Inject
40+
MyAiService aiService;
41+
42+
@Test
43+
@ActivateRequestContext
44+
void testRewriting() {
45+
assertEquals(MessageTruncatingGuardrail.MAX_LENGTH, aiService.test("first prompt", "second prompt").length());
46+
}
47+
48+
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
49+
public interface MyAiService {
50+
51+
@UserMessage("Given {first} and {second} do something")
52+
@InputGuardrails(MessageTruncatingGuardrail.class)
53+
String test(String first, String second);
54+
55+
}
56+
57+
@RequestScoped
58+
public static class MessageTruncatingGuardrail implements InputGuardrail {
59+
60+
static final int MAX_LENGTH = 20;
61+
62+
@Override
63+
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
64+
String text = um.singleText();
65+
return successWith(text.substring(0, MAX_LENGTH));
66+
}
67+
}
68+
69+
public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {
70+
71+
@Override
72+
public ChatLanguageModel get() {
73+
return new EchoChatModel();
74+
}
75+
}
76+
77+
public static class EchoChatModel implements ChatLanguageModel {
78+
79+
@Override
80+
public Response<AiMessage> generate(List<ChatMessage> messages) {
81+
return new Response<>(new AiMessage(((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText()));
82+
}
83+
}
84+
85+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
86+
@Override
87+
public ChatMemoryProvider get() {
88+
return new ChatMemoryProvider() {
89+
@Override
90+
public ChatMemory get(Object memoryId) {
91+
return new NoopChatMemory();
92+
}
93+
};
94+
}
95+
}
96+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ enum Result {
2929
FATAL
3030
}
3131

32-
boolean isSuccess();
32+
Result getResult();
33+
34+
default boolean isSuccess() {
35+
return getResult() == Result.SUCCESS || getResult() == Result.SUCCESS_WITH_RESULT;
36+
}
3337

3438
default boolean hasRewrittenResult() {
35-
return false;
39+
return getResult() == Result.SUCCESS_WITH_RESULT;
3640
}
3741

3842
default GuardrailResult<GR> blockRetry() {

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ default InputGuardrailResult success() {
4747
return InputGuardrailResult.success();
4848
}
4949

50+
/**
51+
* @return The result of a successful input guardrail validation with a specific text.
52+
* @param successfulText The text of the successful result.
53+
*/
54+
default InputGuardrailResult successWith(String successfulText) {
55+
return InputGuardrailResult.successWith(successfulText);
56+
}
57+
5058
/**
5159
* @param message A message describing the failure.
5260
* @return The result of a failed input guardrail validation.

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

3+
import java.util.List;
34
import java.util.Map;
45

6+
import dev.langchain4j.data.message.Content;
7+
import dev.langchain4j.data.message.ContentType;
8+
import dev.langchain4j.data.message.TextContent;
59
import dev.langchain4j.data.message.UserMessage;
610
import dev.langchain4j.memory.ChatMemory;
711
import dev.langchain4j.rag.AugmentationResult;
@@ -21,6 +25,14 @@ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
2125

2226
@Override
2327
public InputGuardrailParams withText(String text) {
24-
throw new UnsupportedOperationException();
28+
return new InputGuardrailParams(rewriteUserMessage(userMessage, text), memory, augmentationResult, userMessageTemplate,
29+
variables);
30+
}
31+
32+
public static UserMessage rewriteUserMessage(UserMessage userMessage, String text) {
33+
List<Content> rewrittenContent = userMessage.contents().stream()
34+
.map(c -> c.type() == ContentType.TEXT ? new TextContent(text) : c).toList();
35+
return userMessage.name() == null ? new UserMessage(rewrittenContent)
36+
: new UserMessage(userMessage.name(), rewrittenContent);
2537
}
2638
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,38 @@
1010
* @param result The result of the input guardrail validation.
1111
* @param failures The list of failures, empty if the validation succeeded.
1212
*/
13-
public record InputGuardrailResult(Result result, List<Failure> failures) implements GuardrailResult<InputGuardrailResult> {
13+
public record InputGuardrailResult(Result result, String successfulText,
14+
List<Failure> failures) implements GuardrailResult<InputGuardrailResult> {
1415

1516
private static final InputGuardrailResult SUCCESS = new InputGuardrailResult();
1617

1718
private InputGuardrailResult() {
18-
this(Result.SUCCESS, Collections.emptyList());
19+
this(Result.SUCCESS, null, Collections.emptyList());
20+
}
21+
22+
private InputGuardrailResult(String successfulText) {
23+
this(Result.SUCCESS_WITH_RESULT, successfulText, Collections.emptyList());
1924
}
2025

2126
InputGuardrailResult(List<Failure> failures, boolean fatal) {
22-
this(fatal ? Result.FATAL : Result.FAILURE, failures);
27+
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
2328
}
2429

2530
public static InputGuardrailResult success() {
2631
return InputGuardrailResult.SUCCESS;
2732
}
2833

34+
public static InputGuardrailResult successWith(String successfulText) {
35+
return new InputGuardrailResult(successfulText);
36+
}
37+
2938
public static InputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
3039
return new InputGuardrailResult((List<Failure>) failures, false);
3140
}
3241

3342
@Override
34-
public boolean isSuccess() {
35-
return result == Result.SUCCESS;
43+
public Result getResult() {
44+
return result;
3645
}
3746

3847
@Override
@@ -54,7 +63,7 @@ public InputGuardrailResult validatedBy(Class<? extends Guardrail> guardrailClas
5463
@Override
5564
public String toString() {
5665
if (isSuccess()) {
57-
return "success";
66+
return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success";
5867
}
5968
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
6069
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,8 @@ public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failu
4848
}
4949

5050
@Override
51-
public boolean isSuccess() {
52-
return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT;
53-
}
54-
55-
@Override
56-
public boolean hasRewrittenResult() {
57-
return result == Result.SUCCESS_WITH_RESULT;
51+
public Result getResult() {
52+
return result;
5853
}
5954

6055
public boolean isRetry() {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,10 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
211211
ChatMessage augmentedUserMessage = ar.chatMessage();
212212

213213
ChatMemory memory = context.chatMemory(memoryId);
214-
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage,
214+
UserMessage guardrailsMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo,
215+
(UserMessage) augmentedUserMessage,
215216
memory, ar, templateVariables);
216-
List<ChatMessage> messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed);
217+
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
217218
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
218219
finalToolExecutors, ar.contents(), context, memoryId,
219220
methodCreateInfo.isSwitchToWorkerThread());
@@ -223,25 +224,19 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
223224
templateVariables)));
224225
}
225226

226-
private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
227+
private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
227228
boolean needsMemorySeed) {
228-
List<ChatMessage> messagesToSend;
229-
ChatMemory chatMemory;
230-
if (context.hasChatMemory()) {
231-
chatMemory = context.chatMemory(memoryId);
232-
messagesToSend = createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage,
233-
chatMemory, needsMemorySeed, context, methodCreateInfo);
234-
} else {
235-
messagesToSend = createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage,
236-
needsMemorySeed, context, methodCreateInfo);
237-
}
238-
return messagesToSend;
229+
return context.hasChatMemory()
230+
? createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage,
231+
context.chatMemory(memoryId), needsMemorySeed, context, methodCreateInfo)
232+
: createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage,
233+
needsMemorySeed, context, methodCreateInfo);
239234
}
240235
});
241236
}
242237
}
243238

244-
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
239+
userMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
245240
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
246241
augmentationResult, templateVariables);
247242

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.runtime.aiservice;
22

33
import static dev.langchain4j.data.message.UserMessage.userMessage;
4+
import static io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.rewriteUserMessage;
45

56
import java.util.ArrayList;
67
import java.util.Collections;
@@ -32,7 +33,7 @@
3233

3334
public class GuardrailsSupport {
3435

35-
public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
36+
public static UserMessage invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
3637
ChatMemory chatMemory, AugmentationResult augmentationResult, Map<String, Object> templateVariables) {
3738
InputGuardrailResult result;
3839
try {
@@ -48,6 +49,11 @@ public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateI
4849
if (!result.isSuccess()) {
4950
throw new GuardrailException(result.toString(), result.getFirstFailureException());
5051
}
52+
53+
if (result.hasRewrittenResult()) {
54+
userMessage = rewriteUserMessage(userMessage, result.successfulText());
55+
}
56+
return userMessage;
5157
}
5258

5359
public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateInfo methodCreateInfo,

0 commit comments

Comments
 (0)