Skip to content

Commit 9f6b857

Browse files
committed
Add AbstractJsonExtractorOutputGuardrail to guardrails docs
1 parent 74b7c12 commit 9f6b857

File tree

1 file changed

+78
-32
lines changed

1 file changed

+78
-32
lines changed

docs/modules/ROOT/pages/guardrails.adoc

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -518,62 +518,108 @@ It may happen that the output generated by the LLM is not completely satisfying,
518518
{"name":"Alex", age:18} Alex is 18 since he became an adult a few days ago.
519519
----
520520

521-
In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method as in the following example.
521+
In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method.
522+
523+
This scenario is so common that it is already provided an abstract class implementing the `OutputGuardrail` interface and performing this programmatic json sanitization out-of-the-box.
522524

523525
[source,java]
524526
----
525-
import com.fasterxml.jackson.core.JsonProcessingException;
526-
import com.fasterxml.jackson.databind.ObjectMapper;
527-
import dev.langchain4j.data.message.AiMessage;
528-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
529-
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
530-
import jakarta.enterprise.context.ApplicationScoped;
531527
import jakarta.inject.Inject;
532528
import org.jboss.logging.Logger;
529+
import com.fasterxml.jackson.core.type.TypeReference;
530+
import dev.langchain4j.data.message.AiMessage;
533531
534-
@ApplicationScoped
535-
public class ValidJsonOutputGuardrail implements OutputGuardrail {
536-
537-
private static final ObjectMapper MAPPER = new ObjectMapper();
532+
public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {
538533
539534
@Inject
540535
Logger logger;
541536
537+
@Inject
538+
JsonGuardrailsUtils jsonGuardrailsUtils;
539+
540+
protected AbstractJsonExtractorOutputGuardrail() {
541+
if (getOutputClass() == null && getOutputType() == null) {
542+
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
543+
}
544+
}
545+
542546
@Override
543547
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
544548
String llmResponse = responseFromLLM.text();
545-
logger.infof("LLM output: %s", llmResponse);
549+
logger.debugf("LLM output: %s", llmResponse);
546550
547-
if (validateJson(llmResponse, Person.class)) {
548-
return success();
551+
Object result = deserialize(llmResponse);
552+
if (result != null) {
553+
return successWith(llmResponse, result);
549554
}
550555
551-
String json = trimNonJson(llmResponse);
552-
if (json != null && validateJson(json, Person.class)) {
553-
return successWith(json);
556+
String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
557+
if (json != null) {
558+
result = deserialize(json);
559+
if (result != null) {
560+
return successWith(json, result);
561+
}
554562
}
555563
556564
return reprompt("Invalid JSON",
557-
"Make sure you return a valid JSON object following "
558-
+ "the specified format");
565+
"Make sure you return a valid JSON object following "
566+
+ "the specified format");
559567
}
560568
561-
private static String trimNonJson(String llmResponse) {
562-
int jsonStart = llmResponse.indexOf("{");
563-
int jsonEnd = llmResponse.indexOf("}");
564-
if (jsonStart >= 0 && jsonEnd >= 0 && jsonStart < jsonEnd) {
565-
return llmResponse.substring(jsonStart + 1, jsonEnd);
566-
}
569+
protected Object deserialize(String llmResponse) {
570+
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
571+
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
572+
}
573+
574+
protected Class<?> getOutputClass() {
567575
return null;
568576
}
569577
570-
private static boolean validateJson(String json, Class<?> expectedOutputClass) {
571-
try {
572-
MAPPER.readValue(json, expectedOutputClass);
573-
return true;
574-
} catch (JsonProcessingException e) {
575-
return false;
576-
}
578+
protected TypeReference<?> getOutputType() {
579+
return null;
580+
}
581+
}
582+
----
583+
584+
This implementation, first tries to deserialize the LLM response into the expected class to be returned by the data extraction. If this doesn't succeed it tries to trim away the non-json part of the response and perform the deserialization again. Note that in both case together with the json response, either the original LLM one or the one programmatically trimmed, the `successWith` method also returns the resulting deserialized object, so that it could be used directly as the final response of the data extraction, instead of uselessly having to execute a second deserialization. In case that both these attempts of deserialization fail then the `OutputGuardrail` perform a reprompt, hoping that the LLM will finally produce a valid json string.
585+
586+
In this way if for example there is an AI service trying to extract the data of a customer from the user prompts like the following
587+
588+
[source,java]
589+
----
590+
@RegisterAiService
591+
public interface CustomerExtractor {
592+
593+
@UserMessage("Extract information about a customer from this text '{text}'. The response must contain only the JSON with customer's data and without any other sentence.")
594+
@OutputGuardrails(CustomerExtractionOutputGuardrail.class)
595+
Customer extractData(String text);
596+
}
597+
----
598+
599+
it is possible to use with it an `OutputGuardrail` that sanitizes the json LLM response by simply extending the former abstract class and declaring which is the expected output class of the data extraction.
600+
601+
[source,java]
602+
----
603+
@ApplicationScoped
604+
public class CustomerExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
605+
606+
@Override
607+
protected Class<?> getOutputClass() {
608+
return Customer.class;
609+
}
610+
}
611+
----
612+
613+
Note that if the data extraction requires a generified Java type, like a `List<Customer>`, it is conversely necessary to extend the `getOutputType` and return a Jackson's `TypeReference` as it follows:
614+
615+
[source,java]
616+
----
617+
@ApplicationScoped
618+
public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
619+
620+
@Override
621+
protected TypeReference<?> getOutputType() {
622+
return new TypeReference<List<Customer>>() {};
577623
}
578624
}
579625
----

0 commit comments

Comments
 (0)