Skip to content

Commit 78fcaf4

Browse files
committed
Provide an abstract output guardrails for json data extraction
1 parent 6efb975 commit 78fcaf4

File tree

10 files changed

+265
-29
lines changed

10 files changed

+265
-29
lines changed

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ void testThatRepromptAfterRewriteIsNotAllowed() {
9494
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
9595
}
9696

97+
@Test
98+
@ActivateRequestContext
99+
void testThatRewritesTheOutputWithAResult() {
100+
assertThat(aiService.rewritingSuccessWithResult("1", "foo")).isSameAs(RewritingGuardrailWithResult.RESULT);
101+
}
102+
97103
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
98104
public interface MyAiService {
99105

@@ -112,6 +118,9 @@ public interface MyAiService {
112118
@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
113119
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);
114120

121+
@OutputGuardrails({ FirstRewritingGuardrail.class, RewritingGuardrailWithResult.class })
122+
Integer rewritingSuccessWithResult(@MemoryId String mem, @UserMessage String message);
123+
115124
}
116125

117126
@RequestScoped
@@ -206,6 +215,18 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) {
206215
}
207216
}
208217

218+
@RequestScoped
219+
public static class RewritingGuardrailWithResult implements OutputGuardrail {
220+
221+
static final Integer RESULT = 1_000;
222+
223+
@Override
224+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
225+
String text = responseFromLLM.text();
226+
return successWith(text + ",2", RESULT);
227+
}
228+
}
229+
209230
@RequestScoped
210231
public static class RepromptingGuardrail implements OutputGuardrail {
211232

core/runtime/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@
140140
<artifactId>junit-jupiter</artifactId>
141141
<scope>test</scope>
142142
</dependency>
143+
<dependency>
144+
<groupId>io.quarkus</groupId>
145+
<artifactId>quarkus-junit5</artifactId>
146+
<scope>test</scope>
147+
</dependency>
143148
</dependencies>
144149
<build>
145150
<plugins>
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package io.quarkiverse.langchain4j.guardrails;
2+
3+
import jakarta.inject.Inject;
4+
5+
import org.jboss.logging.Logger;
6+
7+
import com.fasterxml.jackson.core.type.TypeReference;
8+
9+
import dev.langchain4j.data.message.AiMessage;
10+
11+
public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {
12+
13+
@Inject
14+
Logger logger;
15+
16+
@Inject
17+
JsonGuardrailsUtils jsonGuardrailsUtils;
18+
19+
protected AbstractJsonExtractorOutputGuardrail() {
20+
if (getOutputClass() == null && getOutputType() == null) {
21+
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
22+
}
23+
}
24+
25+
@Override
26+
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
27+
String llmResponse = responseFromLLM.text();
28+
logger.debugf("LLM output: %s", llmResponse);
29+
30+
Object result = deserialize(llmResponse);
31+
if (result != null) {
32+
return successWith(llmResponse, result);
33+
}
34+
35+
String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
36+
if (json != null) {
37+
result = deserialize(json);
38+
if (result != null) {
39+
return successWith(json, result);
40+
}
41+
}
42+
43+
return reprompt("Invalid JSON",
44+
"Make sure you return a valid JSON object following "
45+
+ "the specified format");
46+
}
47+
48+
protected Object deserialize(String llmResponse) {
49+
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
50+
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
51+
}
52+
53+
protected Class<?> getOutputClass() {
54+
return null;
55+
}
56+
57+
protected TypeReference<?> getOutputType() {
58+
return null;
59+
}
60+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,19 @@ enum Result {
3131

3232
boolean isSuccess();
3333

34-
default boolean isRewrittenResult() {
34+
default boolean hasRewrittenResult() {
3535
return false;
3636
}
3737

3838
default GuardrailResult<GR> blockRetry() {
3939
throw new UnsupportedOperationException();
4040
}
4141

42-
default String successfulResult() {
42+
default String successfulText() {
43+
throw new UnsupportedOperationException();
44+
}
45+
46+
default Object successfulResult() {
4347
throw new UnsupportedOperationException();
4448
}
4549

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package io.quarkiverse.langchain4j.guardrails;
2+
3+
import jakarta.enterprise.context.ApplicationScoped;
4+
import jakarta.inject.Inject;
5+
6+
import com.fasterxml.jackson.core.JsonProcessingException;
7+
import com.fasterxml.jackson.core.type.TypeReference;
8+
import com.fasterxml.jackson.databind.ObjectMapper;
9+
10+
@ApplicationScoped
11+
class JsonGuardrailsUtils {
12+
13+
@Inject
14+
ObjectMapper objectMapper;
15+
16+
private JsonGuardrailsUtils() {
17+
}
18+
19+
String trimNonJson(String llmResponse) {
20+
int jsonMapStart = llmResponse.indexOf('{');
21+
int jsonListStart = llmResponse.indexOf('[');
22+
if (jsonMapStart < 0 && jsonListStart < 0) {
23+
return null;
24+
}
25+
boolean isJsonMap = jsonMapStart >= 0 && (jsonMapStart < jsonListStart || jsonListStart < 0);
26+
27+
int jsonStart = isJsonMap ? jsonMapStart : jsonListStart;
28+
int jsonEnd = isJsonMap ? llmResponse.lastIndexOf('}') : llmResponse.lastIndexOf(']');
29+
return jsonEnd >= 0 && jsonStart < jsonEnd ? llmResponse.substring(jsonStart, jsonEnd + 1) : null;
30+
}
31+
32+
<T> T deserialize(String json, Class<T> expectedOutputClass) {
33+
try {
34+
return objectMapper.readValue(json, expectedOutputClass);
35+
} catch (JsonProcessingException e) {
36+
return null;
37+
}
38+
}
39+
40+
<T> T deserialize(String json, TypeReference<T> expectedOutputType) {
41+
try {
42+
return objectMapper.readValue(json, expectedOutputType);
43+
} catch (JsonProcessingException e) {
44+
return null;
45+
}
46+
}
47+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,20 @@ default OutputGuardrailResult success() {
5151
}
5252

5353
/**
54-
* @return The result of a successful output guardrail validation with a specific result.
55-
* @param successfulResult The successful result.
54+
* @return The result of a successful output guardrail validation with a specific text.
55+
* @param successfulText The text of the successful result.
5656
*/
57-
default OutputGuardrailResult successWith(String successfulResult) {
58-
return OutputGuardrailResult.successWith(successfulResult);
57+
default OutputGuardrailResult successWith(String successfulText) {
58+
return OutputGuardrailResult.successWith(successfulText);
59+
}
60+
61+
/**
62+
* @return The result of a successful output guardrail validation with a specific text.
63+
* @param successfulText The text of the successful result.
64+
* @param successfulResult The object generated by this successful result.
65+
*/
66+
default OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
67+
return OutputGuardrailResult.successWith(successfulText, successfulResult);
5968
}
6069

6170
/**

core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,37 @@
1010
* @param result The result of the output guardrail validation.
1111
* @param failures The list of failures, empty if the validation succeeded.
1212
*/
13-
public record OutputGuardrailResult(Result result, String successfulResult,
13+
public record OutputGuardrailResult(Result result, String successfulText, Object successfulResult,
1414
List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {
1515

1616
private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();
1717

1818
private OutputGuardrailResult() {
19-
this(Result.SUCCESS, null, Collections.emptyList());
19+
this(Result.SUCCESS, null, null, Collections.emptyList());
2020
}
2121

22-
private OutputGuardrailResult(String successfulResult) {
23-
this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
22+
private OutputGuardrailResult(String successfulText) {
23+
this(Result.SUCCESS_WITH_RESULT, successfulText, null, Collections.emptyList());
24+
}
25+
26+
private OutputGuardrailResult(String successfulText, Object successfulResult) {
27+
this(Result.SUCCESS_WITH_RESULT, successfulText, successfulResult, Collections.emptyList());
2428
}
2529

2630
OutputGuardrailResult(List<Failure> failures, boolean fatal) {
27-
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
31+
this(fatal ? Result.FATAL : Result.FAILURE, null, null, failures);
2832
}
2933

3034
public static OutputGuardrailResult success() {
3135
return SUCCESS;
3236
}
3337

34-
public static OutputGuardrailResult successWith(String successfulResult) {
35-
return new OutputGuardrailResult(successfulResult);
38+
public static OutputGuardrailResult successWith(String successfulText) {
39+
return new OutputGuardrailResult(successfulText);
40+
}
41+
42+
public static OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
43+
return new OutputGuardrailResult(successfulText, successfulResult);
3644
}
3745

3846
public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
@@ -45,7 +53,7 @@ public boolean isSuccess() {
4553
}
4654

4755
@Override
48-
public boolean isRewrittenResult() {
56+
public boolean hasRewrittenResult() {
4957
return result == Result.SUCCESS_WITH_RESULT;
5058
}
5159

@@ -88,7 +96,7 @@ public OutputGuardrailResult validatedBy(Class<? extends Guardrail> guardrailCla
8896
@Override
8997
public String toString() {
9098
if (isSuccess()) {
91-
return "success";
99+
return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success";
92100
}
93101
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
94102
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static dev.langchain4j.data.message.UserMessage.userMessage;
44
import static dev.langchain4j.internal.Exceptions.runtime;
5+
import static dev.langchain4j.model.output.TokenUsage.sum;
56
import static dev.langchain4j.service.AiServices.removeToolMessages;
67
import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded;
78
import static io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil.hasResponseSchema;
@@ -295,7 +296,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
295296
throw new GuardrailsSupport.GuardrailRetryException();
296297
}
297298
} else {
298-
if (result.isRewrittenResult()) {
299+
if (result.hasRewrittenResult()) {
299300
throw new GuardrailException(
300301
"Attempting to rewrite the LLM output while streaming is not allowed");
301302
}
@@ -367,7 +368,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
367368
audit.addLLMToApplicationMessage(response);
368369
}
369370

370-
tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
371+
tokenUsageAccumulator = sum(tokenUsageAccumulator, response.tokenUsage());
371372
}
372373

373374
String userMessageTemplate = methodCreateInfo.getUserMessageTemplate();
@@ -380,7 +381,13 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
380381
// everything worked as expected so let's commit the messages
381382
chatMemory.commit();
382383

383-
response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
384+
Object guardrailResult = response.metadata().get(OutputGuardrailResult.class.getName());
385+
if (guardrailResult != null && isTypeOf(returnType, guardrailResult.getClass())) {
386+
return guardrailResult;
387+
}
388+
389+
response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason(), response.metadata());
390+
384391
if (isResult(returnType)) {
385392
var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType));
386393
return Result.builder()
@@ -389,9 +396,9 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
389396
.sources(augmentationResult == null ? null : augmentationResult.contents())
390397
.finishReason(response.finishReason())
391398
.build();
392-
} else {
393-
return SERVICE_OUTPUT_PARSER.parse(response, returnType);
394399
}
400+
401+
return SERVICE_OUTPUT_PARSER.parse(response, returnType);
395402
}
396403

397404
private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import static dev.langchain4j.data.message.UserMessage.userMessage;
44

5-
import java.util.*;
5+
import java.util.ArrayList;
6+
import java.util.Collections;
7+
import java.util.List;
8+
import java.util.Map;
69
import java.util.function.Function;
710

811
import jakarta.enterprise.inject.spi.CDI;
@@ -100,17 +103,22 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
100103
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
101104
}
102105

103-
if (result.isRewrittenResult()) {
104-
response = rewriteResponseWithText(response, result.successfulResult());
106+
if (result.hasRewrittenResult()) {
107+
response = rewriteResponse(response, result);
105108
}
106109

107110
return response;
108111
}
109112

110-
public static Response<AiMessage> rewriteResponseWithText(Response<AiMessage> response, String text) {
113+
public static Response<AiMessage> rewriteResponse(Response<AiMessage> response, OutputGuardrailResult result) {
111114
List<ToolExecutionRequest> tools = response.content().toolExecutionRequests();
112-
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
113-
return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
115+
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(result.successfulText(), tools)
116+
: new AiMessage(result.successfulText());
117+
Map<String, Object> metadata = response.metadata();
118+
if (result.successfulResult() != null) {
119+
metadata.put(OutputGuardrailResult.class.getName(), result.successfulResult());
120+
}
121+
return new Response<>(content, response.tokenUsage(), response.finishReason(), metadata);
114122
}
115123

116124
@SuppressWarnings("unchecked")
@@ -173,10 +181,10 @@ private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams p
173181
for (Class<? extends Guardrail> bean : classes) {
174182
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
175183
if (result.isFatal()) {
176-
return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
184+
return accumulatedResults.hasRewrittenResult() ? (GR) result.blockRetry() : result;
177185
}
178-
if (result.isRewrittenResult()) {
179-
params = params.withText(result.successfulResult());
186+
if (result.hasRewrittenResult()) {
187+
params = params.withText(result.successfulText());
180188
}
181189
accumulatedResults = compose(accumulatedResults, result, producer);
182190
}

0 commit comments

Comments
 (0)