Skip to content

Commit 180d25e

Browse files
946 - Add prompt template and variables to input guardrails
1 parent 89ea554 commit 180d25e

File tree

5 files changed

+279
-16
lines changed

5 files changed

+279
-16
lines changed

core/deployment/pom.xml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@
6363
<artifactId>quarkus-junit5-internal</artifactId>
6464
<scope>test</scope>
6565
</dependency>
66-
66+
<dependency>
67+
<groupId>io.quarkus</groupId>
68+
<artifactId>quarkus-websockets-next</artifactId>
69+
<scope>test</scope>
70+
</dependency>
6771
<dependency>
6872
<groupId>org.assertj</groupId>
6973
<artifactId>assertj-core</artifactId>
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.RequestScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.inject.Inject;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
21+
import dev.langchain4j.model.chat.ChatLanguageModel;
22+
import dev.langchain4j.model.output.Response;
23+
import dev.langchain4j.service.MemoryId;
24+
import dev.langchain4j.service.UserMessage;
25+
import io.quarkiverse.langchain4j.RegisterAiService;
26+
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
27+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
28+
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
29+
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
30+
import io.quarkus.test.QuarkusUnitTest;
31+
32+
public class InputGuardrailPromptTemplateTest {
33+
34+
@RegisterExtension
35+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
36+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
37+
.addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class,
38+
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
39+
@Inject
40+
MyAiService aiService;
41+
42+
@Inject
43+
GuardrailValidation guardrailValidation;
44+
45+
@Test
46+
@ActivateRequestContext
47+
void shouldWorkNoParameters() {
48+
aiService.getJoke();
49+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke");
50+
assertThat(guardrailValidation.spyVariables()).isEmpty();
51+
}
52+
53+
@Test
54+
@ActivateRequestContext
55+
void shouldWorkWithMemoryId() {
56+
aiService.getAnotherJoke("memory-id-001");
57+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
58+
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
59+
"memoryId", "memory-id-001",
60+
"it", "memory-id-001" // is this correct?
61+
));
62+
}
63+
64+
@Test
65+
@ActivateRequestContext
66+
void shouldWorkWithNoMemoryIdAndOneParameter() {
67+
aiService.sayHiToMyFriendNoMemory("Rambo");
68+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
69+
assertThat(guardrailValidation.spyVariables())
70+
.containsExactlyInAnyOrderEntriesOf(Map.of(
71+
"friend", "Rambo",
72+
"it", "Rambo"));
73+
}
74+
75+
@Test
76+
@ActivateRequestContext
77+
void shouldWorkWithMemoryIdAndOneParameter() {
78+
aiService.sayHiToMyFriend("1", "Chuck Norris");
79+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
80+
assertThat(guardrailValidation.spyVariables())
81+
.containsExactlyInAnyOrderEntriesOf(Map.of(
82+
"friend", "Chuck Norris",
83+
"mem", "1"));
84+
}
85+
86+
@Test
87+
@ActivateRequestContext
88+
void shouldWorkWithNoMemoryIdAndThreeParameters() {
89+
aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone");
90+
assertThat(guardrailValidation.spyUserMessageTemplate())
91+
.isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!");
92+
assertThat(guardrailValidation.spyVariables())
93+
.containsExactlyInAnyOrderEntriesOf(Map.of(
94+
"topic1", "Chuck Norris",
95+
"topic2", "Jean-Claude Van Damme",
96+
"topic3", "Silvester Stallone"));
97+
}
98+
99+
@Test
100+
@ActivateRequestContext
101+
void shouldWorkWithNoMemoryIdAndList() {
102+
aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
103+
assertThat(guardrailValidation.spyUserMessageText())
104+
.isEqualTo("Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]!");
105+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!");
106+
assertThat(guardrailValidation.spyVariables())
107+
.containsExactlyInAnyOrderEntriesOf(Map.of(
108+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
109+
"it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")));
110+
}
111+
112+
@Test
113+
@ActivateRequestContext
114+
void shouldWorkWithMemoryIdAndList() {
115+
aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
116+
assertThat(guardrailValidation.spyUserMessageText()).isEqualTo(
117+
"Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]! This is my memory id: memory-id-007");
118+
assertThat(guardrailValidation.spyUserMessageTemplate())
119+
.isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}");
120+
assertThat(guardrailValidation.spyVariables())
121+
.containsExactlyInAnyOrderEntriesOf(Map.of(
122+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
123+
"memoryId", "memory-id-007"));
124+
}
125+
126+
@Test
127+
@ActivateRequestContext
128+
void shouldWorkWithMemoryIdAndOneItemFromList() {
129+
aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
130+
assertThat(guardrailValidation.spyUserMessageText())
131+
.isEqualTo("Tell me something about Chuck Norris! This is my memory id: memory-id-007");
132+
assertThat(guardrailValidation.spyUserMessageTemplate())
133+
.isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}");
134+
assertThat(guardrailValidation.spyVariables())
135+
.containsExactlyInAnyOrderEntriesOf(Map.of(
136+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
137+
"memoryId", "memory-id-007"));
138+
}
139+
140+
@Test
141+
@ActivateRequestContext
142+
void shouldWorkWithNoUserMessage() {
143+
// This is a special case where the UserMessage annotation is not present
144+
// The prompt template doesn't exist in this case
145+
// But the current implementation use the parameter name as prompt template
146+
// Not sure if this is the correct behavior, should we always have @UserMessage?
147+
// I need some thoughts on this case
148+
aiService.saySomething("Is this a parameter or a prompt?");
149+
assertThat(guardrailValidation.spyUserMessageTemplate()).isNull();
150+
assertThat(guardrailValidation.spyVariables()).isEmpty();
151+
}
152+
153+
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
154+
public interface MyAiService {
155+
156+
@InputGuardrails(GuardrailValidation.class)
157+
@UserMessage("Tell me a joke")
158+
String getJoke();
159+
160+
@UserMessage("Tell me another joke")
161+
@InputGuardrails(GuardrailValidation.class)
162+
String getAnotherJoke(@MemoryId String memoryId);
163+
164+
@UserMessage("Say hi to my friend {friend}!")
165+
@InputGuardrails(GuardrailValidation.class)
166+
String sayHiToMyFriendNoMemory(String friend);
167+
168+
@UserMessage("Say hi to my friend {friend}!")
169+
@InputGuardrails(GuardrailValidation.class)
170+
String sayHiToMyFriend(@MemoryId String mem, String friend);
171+
172+
@UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!")
173+
@InputGuardrails(GuardrailValidation.class)
174+
String sayHiToMyFriends(String topic1, String topic2, String topic3);
175+
176+
@UserMessage("Tell me something about {topics}!")
177+
@InputGuardrails(GuardrailValidation.class)
178+
String sayHiToMyFriends(List<String> topics);
179+
180+
@UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}")
181+
@InputGuardrails(GuardrailValidation.class)
182+
String sayHiToMyFriends(@MemoryId String memoryId, List<String> topics);
183+
184+
@UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}")
185+
@InputGuardrails(GuardrailValidation.class)
186+
String sayHiToMyFriend(@MemoryId String memoryId, List<String> topics);
187+
188+
@InputGuardrails(GuardrailValidation.class)
189+
String saySomething(String isThisAPromptOrAParameter);
190+
191+
}
192+
193+
@RequestScoped
194+
public static class GuardrailValidation implements InputGuardrail {
195+
196+
InputGuardrailParams params;
197+
198+
public InputGuardrailResult validate(InputGuardrailParams params) {
199+
this.params = params;
200+
return success();
201+
}
202+
203+
public String spyUserMessageTemplate() {
204+
return params.userMessageTemplate();
205+
}
206+
207+
public String spyUserMessageText() {
208+
return params.userMessage().singleText();
209+
}
210+
211+
public Map<String, Object> spyVariables() {
212+
return params.variables();
213+
}
214+
}
215+
216+
public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {
217+
218+
@Override
219+
public ChatLanguageModel get() {
220+
return new MyChatModel();
221+
}
222+
}
223+
224+
public static class MyChatModel implements ChatLanguageModel {
225+
226+
@Override
227+
public Response<AiMessage> generate(List<ChatMessage> messages) {
228+
return new Response<>(new AiMessage("Hi!"));
229+
}
230+
}
231+
232+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
233+
@Override
234+
public ChatMemoryProvider get() {
235+
return memoryId -> new NoopChatMemory();
236+
}
237+
}
238+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

33
import java.util.Arrays;
4+
import java.util.Map;
45

56
import dev.langchain4j.data.message.UserMessage;
67
import io.smallrye.common.annotation.Experimental;

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,12 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
146146
Optional<SystemMessage> systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs,
147147
context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList());
148148
UserMessage userMessage = prepareUserMessage(context, methodCreateInfo, methodArgs);
149+
Map<String, Object> templateParams = getTemplateParams(methodArgs, methodCreateInfo.getUserMessageInfo());
149150

150151
Type returnType = methodCreateInfo.getReturnType();
151152
if (isImage(returnType) || isResultImage(returnType)) {
152-
return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType);
153+
return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType,
154+
templateParams);
153155
}
154156

155157
if (audit != null) {
@@ -203,8 +205,9 @@ public AugmentationResult get() {
203205
@Override
204206
public Flow.Publisher<?> apply(AugmentationResult ar) {
205207
ChatMessage augmentedUserMessage = ar.chatMessage();
208+
206209
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage,
207-
context.chatMemory(memoryId), ar);
210+
context.chatMemory(memoryId), ar, templateParams);
208211
List<ChatMessage> messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed);
209212
return new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
210213
finalToolExecutors, ar.contents(), context, memoryId);
@@ -230,7 +233,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
230233

231234
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
232235
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
233-
augmentationResult);
236+
augmentationResult, templateParams);
234237

235238
CommittableChatMemory chatMemory;
236239
List<ChatMessage> messagesToSend;
@@ -379,7 +382,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
379382

380383
private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,
381384
Audit audit, Optional<SystemMessage> systemMessage, UserMessage userMessage,
382-
Object memoryId, Type returnType) {
385+
Object memoryId, Type returnType, Map<String, Object> templateParams) {
383386
String imagePrompt;
384387
if (systemMessage.isPresent()) {
385388
imagePrompt = systemMessage.get().text() + "\n" + userMessage.singleText();
@@ -397,7 +400,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC
397400
// TODO: we can only support input guardrails for now as it is tied to AiMessage
398401
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
399402
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
400-
augmentationResult);
403+
augmentationResult, templateParams);
401404

402405
Response<Image> imageResponse = context.imageModel.generate(imagePrompt);
403406
if (audit != null) {
@@ -589,12 +592,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
589592

590593
if (userMessageInfo.template().isPresent()) {
591594
AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get();
592-
Map<String, Object> templateParams = new HashMap<>();
593-
Map<String, Integer> nameToParamPosition = templateInfo.nameToParamPosition();
594-
for (var entry : nameToParamPosition.entrySet()) {
595-
Object value = transformTemplateParamValue(methodArgs[entry.getValue()]);
596-
templateParams.put(entry.getKey(), value);
597-
}
595+
Map<String, Object> templateParams = getTemplateParams(methodArgs, userMessageInfo);
598596
String templateText;
599597
if (templateInfo.text().isPresent()) {
600598
templateText = templateInfo.text().get();
@@ -642,6 +640,23 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic
642640
}
643641
}
644642

643+
private static Map<String, Object> getTemplateParams(Object[] methodArgs,
644+
AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo) {
645+
Map<String, Object> templateParams = new HashMap<>();
646+
647+
if (userMessageInfo.template().isPresent()) {
648+
AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get();
649+
Map<String, Integer> nameToParamPosition = templateInfo.nameToParamPosition();
650+
651+
for (var entry : nameToParamPosition.entrySet()) {
652+
Object value = transformTemplateParamValue(methodArgs[entry.getValue()]);
653+
templateParams.put(entry.getKey(), value);
654+
}
655+
}
656+
657+
return templateParams;
658+
}
659+
645660
private static UserMessage createUserMessage(String name, ImageContent imageContent, String text) {
646661
if (name == null) {
647662
if (imageContent == null) {

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import static dev.langchain4j.data.message.UserMessage.userMessage;
44

5-
import java.util.ArrayList;
6-
import java.util.List;
5+
import java.util.*;
76
import java.util.function.Function;
87

98
import jakarta.enterprise.inject.spi.CDI;
@@ -30,11 +29,17 @@
3029
public class GuardrailsSupport {
3130

3231
public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
33-
ChatMemory chatMemory, AugmentationResult augmentationResult) {
32+
ChatMemory chatMemory, AugmentationResult augmentationResult, Map<String, Object> templateParams) {
3433
InputGuardrailResult result;
3534
try {
35+
36+
Optional<String> userMessageTemplateOpt = methodCreateInfo.getUserMessageInfo().template()
37+
.flatMap(AiServiceMethodCreateInfo.TemplateInfo::text);
38+
39+
String userMessageTemplate = userMessageTemplateOpt.orElse(null);
40+
3641
result = invokeInputGuardRails(methodCreateInfo,
37-
new InputGuardrailParams(userMessage, chatMemory, augmentationResult));
42+
new InputGuardrailParams(userMessage, chatMemory, augmentationResult, promptTemplate, templateParams));
3843
} catch (Exception e) {
3944
throw new GuardrailException(e.getMessage(), e);
4045
}

0 commit comments

Comments
 (0)