Skip to content

Commit 5f55595

Browse files
946 - Add prompt template and variables to output guardrails
1 parent 180d25e commit 5f55595

File tree

8 files changed

+273
-32
lines changed

8 files changed

+273
-32
lines changed

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import dev.langchain4j.service.UserMessage;
2525
import io.quarkiverse.langchain4j.RegisterAiService;
2626
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
27+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
2728
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
2829
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
2930
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
@@ -57,8 +58,7 @@ void shouldWorkWithMemoryId() {
5758
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
5859
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
5960
"memoryId", "memory-id-001",
60-
"it", "memory-id-001" // is this correct?
61-
));
61+
"it", "memory-id-001"));
6262
}
6363

6464
@Test
@@ -140,11 +140,7 @@ void shouldWorkWithMemoryIdAndOneItemFromList() {
140140
@Test
141141
@ActivateRequestContext
142142
void shouldWorkWithNoUserMessage() {
143-
// This is a special case where the UserMessage annotation is not present
144-
// The prompt template doesn't exist in this case
145-
// But the current implementation use the parameter name as prompt template
146-
// Not sure if this is the correct behavior, should we always have @UserMessage?
147-
// I need some thoughts on this case
143+
// UserMessage annotation is not provided, then no user message template should be available
148144
aiService.saySomething("Is this a parameter or a prompt?");
149145
assertThat(guardrailValidation.spyUserMessageTemplate()).isNull();
150146
assertThat(guardrailValidation.spyVariables()).isEmpty();
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.RequestScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.inject.Inject;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
21+
import dev.langchain4j.model.chat.ChatLanguageModel;
22+
import dev.langchain4j.model.output.Response;
23+
import dev.langchain4j.service.MemoryId;
24+
import dev.langchain4j.service.UserMessage;
25+
import io.quarkiverse.langchain4j.RegisterAiService;
26+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
27+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
28+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
29+
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
30+
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
31+
import io.quarkus.test.QuarkusUnitTest;
32+
33+
public class OutputGuardrailPromptTemplateTest {
34+
35+
@RegisterExtension
36+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
37+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
38+
.addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class,
39+
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
40+
@Inject
41+
MyAiService aiService;
42+
43+
@Inject
44+
GuardrailValidation guardrailValidation;
45+
46+
@Test
47+
@ActivateRequestContext
48+
void shouldWorkNoParameters() {
49+
aiService.getJoke();
50+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke");
51+
assertThat(guardrailValidation.spyVariables()).isEmpty();
52+
}
53+
54+
@Test
55+
@ActivateRequestContext
56+
void shouldWorkWithMemoryId() {
57+
aiService.getAnotherJoke("memory-id-001");
58+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
59+
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
60+
"memoryId", "memory-id-001",
61+
"it", "memory-id-001"));
62+
}
63+
64+
@Test
65+
@ActivateRequestContext
66+
void shouldWorkWithNoMemoryIdAndOneParameter() {
67+
aiService.sayHiToMyFriendNoMemory("Rambo");
68+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
69+
assertThat(guardrailValidation.spyVariables())
70+
.containsExactlyInAnyOrderEntriesOf(Map.of(
71+
"friend", "Rambo",
72+
"it", "Rambo"));
73+
}
74+
75+
@Test
76+
@ActivateRequestContext
77+
void shouldWorkWithMemoryIdAndOneParameter() {
78+
aiService.sayHiToMyFriend("1", "Chuck Norris");
79+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
80+
assertThat(guardrailValidation.spyVariables())
81+
.containsExactlyInAnyOrderEntriesOf(Map.of(
82+
"friend", "Chuck Norris",
83+
"mem", "1"));
84+
}
85+
86+
@Test
87+
@ActivateRequestContext
88+
void shouldWorkWithNoMemoryIdAndThreeParameters() {
89+
aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone");
90+
assertThat(guardrailValidation.spyUserMessageTemplate())
91+
.isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!");
92+
assertThat(guardrailValidation.spyVariables())
93+
.containsExactlyInAnyOrderEntriesOf(Map.of(
94+
"topic1", "Chuck Norris",
95+
"topic2", "Jean-Claude Van Damme",
96+
"topic3", "Silvester Stallone"));
97+
}
98+
99+
@Test
100+
@ActivateRequestContext
101+
void shouldWorkWithNoMemoryIdAndList() {
102+
aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
103+
104+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!");
105+
assertThat(guardrailValidation.spyVariables())
106+
.containsExactlyInAnyOrderEntriesOf(Map.of(
107+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
108+
"it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")));
109+
}
110+
111+
@Test
112+
@ActivateRequestContext
113+
void shouldWorkWithMemoryIdAndList() {
114+
aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
115+
116+
assertThat(guardrailValidation.spyUserMessageTemplate())
117+
.isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}");
118+
assertThat(guardrailValidation.spyVariables())
119+
.containsExactlyInAnyOrderEntriesOf(Map.of(
120+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
121+
"memoryId", "memory-id-007"));
122+
}
123+
124+
@Test
125+
@ActivateRequestContext
126+
void shouldWorkWithMemoryIdAndOneItemFromList() {
127+
aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
128+
129+
assertThat(guardrailValidation.spyUserMessageTemplate())
130+
.isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}");
131+
assertThat(guardrailValidation.spyVariables())
132+
.containsExactlyInAnyOrderEntriesOf(Map.of(
133+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
134+
"memoryId", "memory-id-007"));
135+
}
136+
137+
@Test
138+
@ActivateRequestContext
139+
void shouldWorkWithNoUserMessage() {
140+
// UserMessage annotation is not provided, then no user message template should be available
141+
aiService.saySomething("Is this a parameter or a prompt?");
142+
assertThat(guardrailValidation.spyUserMessageTemplate()).isNull();
143+
assertThat(guardrailValidation.spyVariables()).isEmpty();
144+
}
145+
146+
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
147+
public interface MyAiService {
148+
149+
@OutputGuardrails(GuardrailValidation.class)
150+
@UserMessage("Tell me a joke")
151+
String getJoke();
152+
153+
@UserMessage("Tell me another joke")
154+
@OutputGuardrails(GuardrailValidation.class)
155+
String getAnotherJoke(@MemoryId String memoryId);
156+
157+
@UserMessage("Say hi to my friend {friend}!")
158+
@OutputGuardrails(GuardrailValidation.class)
159+
String sayHiToMyFriendNoMemory(String friend);
160+
161+
@UserMessage("Say hi to my friend {friend}!")
162+
@OutputGuardrails(GuardrailValidation.class)
163+
String sayHiToMyFriend(@MemoryId String mem, String friend);
164+
165+
@UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!")
166+
@OutputGuardrails(GuardrailValidation.class)
167+
String sayHiToMyFriends(String topic1, String topic2, String topic3);
168+
169+
@UserMessage("Tell me something about {topics}!")
170+
@OutputGuardrails(GuardrailValidation.class)
171+
String sayHiToMyFriends(List<String> topics);
172+
173+
@UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}")
174+
@OutputGuardrails(GuardrailValidation.class)
175+
String sayHiToMyFriends(@MemoryId String memoryId, List<String> topics);
176+
177+
@UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}")
178+
@OutputGuardrails(GuardrailValidation.class)
179+
String sayHiToMyFriend(@MemoryId String memoryId, List<String> topics);
180+
181+
@OutputGuardrails(GuardrailValidation.class)
182+
String saySomething(String isThisAPromptOrAParameter);
183+
184+
}
185+
186+
@RequestScoped
187+
public static class GuardrailValidation implements OutputGuardrail {
188+
189+
OutputGuardrailParams params;
190+
191+
public OutputGuardrailResult validate(OutputGuardrailParams params) {
192+
this.params = params;
193+
return success();
194+
}
195+
196+
public String spyUserMessageTemplate() {
197+
return params.userMessageTemplate();
198+
}
199+
200+
public Map<String, Object> spyVariables() {
201+
return params.variables();
202+
}
203+
}
204+
205+
public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {
206+
207+
@Override
208+
public ChatLanguageModel get() {
209+
return new MyChatModel();
210+
}
211+
}
212+
213+
public static class MyChatModel implements ChatLanguageModel {
214+
215+
@Override
216+
public Response<AiMessage> generate(List<ChatMessage> messages) {
217+
return new Response<>(new AiMessage("Hi!"));
218+
}
219+
}
220+
221+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
222+
@Override
223+
public ChatMemoryProvider get() {
224+
return memoryId -> new NoopChatMemory();
225+
}
226+
}
227+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

33
import java.util.Arrays;
4-
import java.util.Map;
54

65
import dev.langchain4j.data.message.UserMessage;
76
import io.smallrye.common.annotation.Experimental;

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

3+
import java.util.Map;
4+
35
import dev.langchain4j.data.message.UserMessage;
46
import dev.langchain4j.memory.ChatMemory;
57
import dev.langchain4j.rag.AugmentationResult;
@@ -10,7 +12,10 @@
1012
* @param userMessage the user message, cannot be {@code null}
1113
* @param memory the memory, can be {@code null} or empty
1214
* @param augmentationResult the augmentation result, can be {@code null}
15+
* @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided.
16+
* @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty
1317
*/
1418
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
15-
AugmentationResult augmentationResult) implements GuardrailParams {
19+
AugmentationResult augmentationResult, String userMessageTemplate,
20+
Map<String, Object> variables) implements GuardrailParams {
1621
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.quarkiverse.langchain4j.guardrails;
22

3+
import java.util.Map;
4+
35
import dev.langchain4j.data.message.AiMessage;
46
import dev.langchain4j.memory.ChatMemory;
57
import dev.langchain4j.rag.AugmentationResult;
@@ -10,7 +12,10 @@
1012
* @param responseFromLLM the response from the LLM
1113
* @param memory the memory, can be {@code null} or empty
1214
* @param augmentationResult the augmentation result, can be {@code null}
15+
* @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided.
16+
* @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty
1317
*/
1418
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
15-
AugmentationResult augmentationResult) implements GuardrailParams {
19+
AugmentationResult augmentationResult, String userMessageTemplate,
20+
Map<String, Object> variables) implements GuardrailParams {
1621
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,13 @@ public OutputTokenAccumulator getOutputTokenAccumulator() {
193193
return accumulator;
194194
}
195195

196+
public String getUserMessageTemplate() {
197+
Optional<String> userMessageTemplateOpt = this.getUserMessageInfo().template()
198+
.flatMap(AiServiceMethodCreateInfo.TemplateInfo::text);
199+
200+
return userMessageTemplateOpt.orElse(null);
201+
}
202+
196203
public record UserMessageInfo(Optional<TemplateInfo> template,
197204
Optional<Integer> paramPosition,
198205
Optional<Integer> userNameParamPosition,

0 commit comments

Comments
 (0)