Skip to content

Commit 5faca9b

Browse files
authored
Merge pull request #992 from dennysfredericci/feature/946-Add_prompt_template_and_variables_to_input_guardrail
Add prompt template and variables to input / output guardrails
2 parents afe026d + 54116ee commit 5faca9b

File tree

8 files changed

+530
-23
lines changed

8 files changed

+530
-23
lines changed

core/deployment/pom.xml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@
6363
<artifactId>quarkus-junit5-internal</artifactId>
6464
<scope>test</scope>
6565
</dependency>
66-
66+
<dependency>
67+
<groupId>io.quarkus</groupId>
68+
<artifactId>quarkus-websockets-next</artifactId>
69+
<scope>test</scope>
70+
</dependency>
6771
<dependency>
6872
<groupId>org.assertj</groupId>
6973
<artifactId>assertj-core</artifactId>
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
package io.quarkiverse.langchain4j.test.guardrails;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import java.util.List;
6+
import java.util.Map;
7+
import java.util.function.Supplier;
8+
9+
import jakarta.enterprise.context.RequestScoped;
10+
import jakarta.enterprise.context.control.ActivateRequestContext;
11+
import jakarta.inject.Inject;
12+
13+
import org.jboss.shrinkwrap.api.ShrinkWrap;
14+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import dev.langchain4j.data.message.AiMessage;
19+
import dev.langchain4j.data.message.ChatMessage;
20+
import dev.langchain4j.memory.chat.ChatMemoryProvider;
21+
import dev.langchain4j.model.chat.ChatLanguageModel;
22+
import dev.langchain4j.model.output.Response;
23+
import dev.langchain4j.service.MemoryId;
24+
import dev.langchain4j.service.UserMessage;
25+
import io.quarkiverse.langchain4j.RegisterAiService;
26+
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
27+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
28+
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
29+
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
30+
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
31+
import io.quarkus.test.QuarkusUnitTest;
32+
33+
public class InputGuardrailPromptTemplateTest {
34+
35+
@RegisterExtension
36+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
37+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
38+
.addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class,
39+
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
40+
@Inject
41+
MyAiService aiService;
42+
43+
@Inject
44+
GuardrailValidation guardrailValidation;
45+
46+
@Test
47+
@ActivateRequestContext
48+
void shouldWorkNoParameters() {
49+
aiService.getJoke();
50+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke");
51+
assertThat(guardrailValidation.spyVariables()).isEmpty();
52+
}
53+
54+
@Test
55+
@ActivateRequestContext
56+
void shouldWorkWithMemoryId() {
57+
aiService.getAnotherJoke("memory-id-001");
58+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
59+
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
60+
"memoryId", "memory-id-001",
61+
"it", "memory-id-001"));
62+
}
63+
64+
@Test
65+
@ActivateRequestContext
66+
void shouldWorkWithNoMemoryIdAndOneParameter() {
67+
aiService.sayHiToMyFriendNoMemory("Rambo");
68+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
69+
assertThat(guardrailValidation.spyVariables())
70+
.containsExactlyInAnyOrderEntriesOf(Map.of(
71+
"friend", "Rambo",
72+
"it", "Rambo"));
73+
}
74+
75+
@Test
76+
@ActivateRequestContext
77+
void shouldWorkWithMemoryIdAndOneParameter() {
78+
aiService.sayHiToMyFriend("1", "Chuck Norris");
79+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
80+
assertThat(guardrailValidation.spyVariables())
81+
.containsExactlyInAnyOrderEntriesOf(Map.of(
82+
"friend", "Chuck Norris",
83+
"mem", "1"));
84+
}
85+
86+
@Test
87+
@ActivateRequestContext
88+
void shouldWorkWithNoMemoryIdAndThreeParameters() {
89+
aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone");
90+
assertThat(guardrailValidation.spyUserMessageTemplate())
91+
.isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!");
92+
assertThat(guardrailValidation.spyVariables())
93+
.containsExactlyInAnyOrderEntriesOf(Map.of(
94+
"topic1", "Chuck Norris",
95+
"topic2", "Jean-Claude Van Damme",
96+
"topic3", "Silvester Stallone"));
97+
}
98+
99+
@Test
100+
@ActivateRequestContext
101+
void shouldWorkWithNoMemoryIdAndList() {
102+
aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
103+
assertThat(guardrailValidation.spyUserMessageText())
104+
.isEqualTo("Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]!");
105+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!");
106+
assertThat(guardrailValidation.spyVariables())
107+
.containsExactlyInAnyOrderEntriesOf(Map.of(
108+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
109+
"it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")));
110+
}
111+
112+
@Test
113+
@ActivateRequestContext
114+
void shouldWorkWithMemoryIdAndList() {
115+
aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
116+
assertThat(guardrailValidation.spyUserMessageText()).isEqualTo(
117+
"Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]! This is my memory id: memory-id-007");
118+
assertThat(guardrailValidation.spyUserMessageTemplate())
119+
.isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}");
120+
assertThat(guardrailValidation.spyVariables())
121+
.containsExactlyInAnyOrderEntriesOf(Map.of(
122+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
123+
"memoryId", "memory-id-007"));
124+
}
125+
126+
@Test
127+
@ActivateRequestContext
128+
void shouldWorkWithMemoryIdAndOneItemFromList() {
129+
aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));
130+
assertThat(guardrailValidation.spyUserMessageText())
131+
.isEqualTo("Tell me something about Chuck Norris! This is my memory id: memory-id-007");
132+
assertThat(guardrailValidation.spyUserMessageTemplate())
133+
.isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}");
134+
assertThat(guardrailValidation.spyVariables())
135+
.containsExactlyInAnyOrderEntriesOf(Map.of(
136+
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
137+
"memoryId", "memory-id-007"));
138+
}
139+
140+
@Test
141+
@ActivateRequestContext
142+
void shouldWorkWithNoUserMessage() {
143+
// UserMessage annotation is not provided, then no user message template should be available
144+
aiService.saySomething("Is this a parameter or a prompt?");
145+
assertThat(guardrailValidation.spyUserMessageTemplate()).isEmpty();
146+
assertThat(guardrailValidation.spyVariables()).isEmpty();
147+
}
148+
149+
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
150+
public interface MyAiService {
151+
152+
@InputGuardrails(GuardrailValidation.class)
153+
@UserMessage("Tell me a joke")
154+
String getJoke();
155+
156+
@UserMessage("Tell me another joke")
157+
@InputGuardrails(GuardrailValidation.class)
158+
String getAnotherJoke(@MemoryId String memoryId);
159+
160+
@UserMessage("Say hi to my friend {friend}!")
161+
@InputGuardrails(GuardrailValidation.class)
162+
String sayHiToMyFriendNoMemory(String friend);
163+
164+
@UserMessage("Say hi to my friend {friend}!")
165+
@InputGuardrails(GuardrailValidation.class)
166+
String sayHiToMyFriend(@MemoryId String mem, String friend);
167+
168+
@UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!")
169+
@InputGuardrails(GuardrailValidation.class)
170+
String sayHiToMyFriends(String topic1, String topic2, String topic3);
171+
172+
@UserMessage("Tell me something about {topics}!")
173+
@InputGuardrails(GuardrailValidation.class)
174+
String sayHiToMyFriends(List<String> topics);
175+
176+
@UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}")
177+
@InputGuardrails(GuardrailValidation.class)
178+
String sayHiToMyFriends(@MemoryId String memoryId, List<String> topics);
179+
180+
@UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}")
181+
@InputGuardrails(GuardrailValidation.class)
182+
String sayHiToMyFriend(@MemoryId String memoryId, List<String> topics);
183+
184+
@InputGuardrails(GuardrailValidation.class)
185+
String saySomething(String isThisAPromptOrAParameter);
186+
187+
}
188+
189+
@RequestScoped
190+
public static class GuardrailValidation implements InputGuardrail {
191+
192+
InputGuardrailParams params;
193+
194+
public InputGuardrailResult validate(InputGuardrailParams params) {
195+
this.params = params;
196+
return success();
197+
}
198+
199+
public String spyUserMessageTemplate() {
200+
return params.userMessageTemplate();
201+
}
202+
203+
public String spyUserMessageText() {
204+
return params.userMessage().singleText();
205+
}
206+
207+
public Map<String, Object> spyVariables() {
208+
return params.variables();
209+
}
210+
}
211+
212+
public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {
213+
214+
@Override
215+
public ChatLanguageModel get() {
216+
return new MyChatModel();
217+
}
218+
}
219+
220+
public static class MyChatModel implements ChatLanguageModel {
221+
222+
@Override
223+
public Response<AiMessage> generate(List<ChatMessage> messages) {
224+
return new Response<>(new AiMessage("Hi!"));
225+
}
226+
}
227+
228+
public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
229+
@Override
230+
public ChatMemoryProvider get() {
231+
return memoryId -> new NoopChatMemory();
232+
}
233+
}
234+
}

0 commit comments

Comments
 (0)