Skip to content

Commit 653b06a

Browse files
authored
Merge pull request #1858 from quarkiverse/agentic-validation
Introduce compile time safety to the agentic module
2 parents 7463abc + ab827c1 commit 653b06a

File tree

36 files changed

+1867
-16
lines changed

36 files changed

+1867
-16
lines changed

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticLangChain4jDotNames.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,32 @@
55
import org.jboss.jandex.DotName;
66

77
import dev.langchain4j.agentic.Agent;
8+
import dev.langchain4j.agentic.agent.AgentRequest;
9+
import dev.langchain4j.agentic.agent.AgentResponse;
10+
import dev.langchain4j.agentic.agent.ErrorContext;
11+
import dev.langchain4j.agentic.agent.ErrorRecoveryResult;
12+
import dev.langchain4j.agentic.declarative.ActivationCondition;
13+
import dev.langchain4j.agentic.declarative.AfterAgentInvocation;
14+
import dev.langchain4j.agentic.declarative.BeforeAgentInvocation;
15+
import dev.langchain4j.agentic.declarative.ChatMemoryProviderSupplier;
16+
import dev.langchain4j.agentic.declarative.ChatMemorySupplier;
817
import dev.langchain4j.agentic.declarative.ChatModelSupplier;
918
import dev.langchain4j.agentic.declarative.ConditionalAgent;
19+
import dev.langchain4j.agentic.declarative.ContentRetrieverSupplier;
20+
import dev.langchain4j.agentic.declarative.ErrorHandler;
21+
import dev.langchain4j.agentic.declarative.ExitCondition;
22+
import dev.langchain4j.agentic.declarative.HumanInTheLoop;
23+
import dev.langchain4j.agentic.declarative.HumanInTheLoopResponseSupplier;
1024
import dev.langchain4j.agentic.declarative.LoopAgent;
25+
import dev.langchain4j.agentic.declarative.Output;
1126
import dev.langchain4j.agentic.declarative.ParallelAgent;
27+
import dev.langchain4j.agentic.declarative.ParallelExecutor;
28+
import dev.langchain4j.agentic.declarative.RetrievalAugmentorSupplier;
1229
import dev.langchain4j.agentic.declarative.SequenceAgent;
1330
import dev.langchain4j.agentic.declarative.SubAgent;
1431
import dev.langchain4j.agentic.declarative.SupervisorAgent;
32+
import dev.langchain4j.agentic.declarative.ToolProviderSupplier;
33+
import dev.langchain4j.agentic.declarative.ToolsSupplier;
1534
import dev.langchain4j.agentic.scope.AgenticScope;
1635
import dev.langchain4j.agentic.scope.ResultWithAgenticScope;
1736

@@ -32,6 +51,40 @@ public final class AgenticLangChain4jDotNames {
3251
public static final DotName AGENTIC_SCOPE = DotName.createSimple(AgenticScope.class);
3352
public static final DotName RESULT_WITH_AGENTIC_SCOPE = DotName.createSimple(ResultWithAgenticScope.class);
3453

54+
public static final DotName ACTIVATION_CONDITION = DotName.createSimple(ActivationCondition.class.getName());
55+
public static final DotName BEFORE_AGENT_INVOCATION = DotName.createSimple(BeforeAgentInvocation.class.getName());
56+
public static final DotName AFTER_AGENT_INVOCATION = DotName.createSimple(AfterAgentInvocation.class.getName());
57+
public static final DotName AGENT_REQUEST = DotName.createSimple(AgentRequest.class.getName());
58+
public static final DotName AGENT_RESPONSE = DotName.createSimple(AgentResponse.class.getName());
59+
public static final DotName CHAT_MEMORY_PROVIDER_SUPPLIER = DotName
60+
.createSimple(ChatMemoryProviderSupplier.class.getName());
61+
public static final DotName CHAT_MEMORY_SUPPLIER = DotName
62+
.createSimple(ChatMemorySupplier.class.getName());
63+
public static final DotName CONTENT_RETRIEVER_SUPPLIER = DotName
64+
.createSimple(ContentRetrieverSupplier.class.getName());
65+
public static final DotName ERROR_HANDLER = DotName
66+
.createSimple(ErrorHandler.class.getName());
67+
public static final DotName ERROR_CONTEXT = DotName
68+
.createSimple(ErrorContext.class.getName());
69+
public static final DotName ERROR_RECOVERY_RESULT = DotName
70+
.createSimple(ErrorRecoveryResult.class.getName());
71+
public static final DotName EXIT_CONDITION = DotName
72+
.createSimple(ExitCondition.class.getName());
73+
public static final DotName HUMAN_IN_THE_LOOP = DotName
74+
.createSimple(HumanInTheLoop.class.getName());
75+
public static final DotName HUMAN_IN_THE_LOOP_RESPONSE_SUPPLIER = DotName
76+
.createSimple(HumanInTheLoopResponseSupplier.class.getName());
77+
public static final DotName OUTPUT = DotName
78+
.createSimple(Output.class.getName());
79+
public static final DotName PARALLEL_EXECUTOR = DotName
80+
.createSimple(ParallelExecutor.class.getName());
81+
public static final DotName RETRIEVAL_AUGMENTER_SUPPLIER = DotName
82+
.createSimple(RetrievalAugmentorSupplier.class.getName());
83+
public static final DotName TOOL_PROVIDER_SUPPLIER = DotName
84+
.createSimple(ToolProviderSupplier.class.getName());
85+
public static final DotName TOOL_SUPPLIER = DotName
86+
.createSimple(ToolsSupplier.class.getName());
87+
3588
private AgenticLangChain4jDotNames() {
3689
}
3790
}

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java

Lines changed: 287 additions & 4 deletions
Large diffs are not rendered by default.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment;
2+
3+
import java.lang.reflect.Modifier;
4+
import java.util.List;
5+
import java.util.Set;
6+
import java.util.stream.Collectors;
7+
8+
import org.jboss.jandex.DotName;
9+
import org.jboss.jandex.MethodInfo;
10+
11+
import dev.langchain4j.service.IllegalConfigurationException;
12+
13+
class ValidationUtil {
14+
15+
static void validateStaticMethod(MethodInfo method, DotName annotationName) {
16+
if (!Modifier.isStatic(method.flags())) {
17+
throw new IllegalConfigurationException(
18+
String.format("Methods annotated with '%s' must be static. Offending method is '%s' of class '%s'",
19+
annotationName, method.name(), method.declaringClass().name().toString()));
20+
}
21+
}
22+
23+
static void validateAllowedReturnTypes(MethodInfo method, Set<DotName> allowedReturnTypes, DotName annotationName) {
24+
if (!allowedReturnTypes.contains(method.returnType().name())) {
25+
throw new IllegalConfigurationException(
26+
String.format(
27+
"Methods annotated with '%s' can only use the following return types: '%s'. Offending method is '%s' of class '%s'",
28+
annotationName,
29+
allowedReturnTypes.stream().map(DotName::withoutPackagePrefix).collect(Collectors.joining(",")),
30+
method.name(),
31+
method.declaringClass().name().toString()));
32+
}
33+
}
34+
35+
static void validateRequiredParameterTypes(MethodInfo method, List<DotName> requiredParameterTypes,
36+
DotName annotationName) {
37+
if (method.parameters().size() != requiredParameterTypes.size()) {
38+
throw new IllegalConfigurationException(
39+
String.format(
40+
"Methods annotated with '%s' must use the following parameter types: '%s'. Offending method is '%s' of class '%s'",
41+
annotationName,
42+
requiredParameterTypes.stream().map(DotName::withoutPackagePrefix)
43+
.collect(Collectors.joining(",")),
44+
method.name(),
45+
method.declaringClass().name().toString()));
46+
}
47+
48+
for (int i = 0; i < requiredParameterTypes.size(); i++) {
49+
DotName parameterTypeName = method.parameters().get(i).type().name();
50+
if (!parameterTypeName.equals(requiredParameterTypes.get(i))) {
51+
throw new IllegalConfigurationException(
52+
String.format(
53+
"Methods annotated with '%s' must use the following parameter types: '%s'. Offending method is '%s' of class '%s'",
54+
annotationName,
55+
requiredParameterTypes.stream().map(DotName::withoutPackagePrefix)
56+
.collect(Collectors.joining(",")),
57+
method.name(),
58+
method.declaringClass().name().toString()));
59+
}
60+
}
61+
}
62+
63+
static void validateNoMethodParameters(MethodInfo method, DotName annotationName) {
64+
if (!method.parameters().isEmpty()) {
65+
throw new IllegalConfigurationException(
66+
String.format(
67+
"Methods annotated with '%s' cannot have any method parameters. Offending method is '%s' of class '%s'",
68+
annotationName,
69+
method.name(),
70+
method.declaringClass().name().toString()));
71+
}
72+
}
73+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.BeforeAgentInvocation;
12+
import dev.langchain4j.service.IllegalConfigurationException;
13+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
14+
import io.quarkus.test.QuarkusUnitTest;
15+
16+
public class NonAgentRequestParameterTypeBeforeAgentInvocationTest {
17+
18+
@RegisterExtension
19+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
20+
.setArchiveProducer(
21+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithListener.class))
22+
.assertException(
23+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
24+
.hasMessageContaining("AgentRequest"));
25+
26+
@Test
27+
public void test() {
28+
fail("should never be called");
29+
}
30+
31+
public interface StyleReviewLoopAgentWithListener extends Agents.StyleReviewLoopAgent {
32+
33+
@BeforeAgentInvocation
34+
static void beforeAgentInvocation(Void request) {
35+
36+
}
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.AfterAgentInvocation;
12+
import dev.langchain4j.service.IllegalConfigurationException;
13+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
14+
import io.quarkus.test.QuarkusUnitTest;
15+
16+
public class NonAgentResponseParameterTypeAfterAgentInvocationTest {
17+
18+
@RegisterExtension
19+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
20+
.setArchiveProducer(
21+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithListener.class))
22+
.assertException(
23+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
24+
.hasMessageContaining("AgentResponse"));
25+
26+
@Test
27+
public void test() {
28+
fail("should never be called");
29+
}
30+
31+
public interface StyleReviewLoopAgentWithListener extends Agents.StyleReviewLoopAgent {
32+
33+
@AfterAgentInvocation
34+
static void afterAgentInvocation(Void response) {
35+
36+
}
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.junit.jupiter.api.Assertions.fail;
5+
6+
import org.assertj.core.api.Assertions;
7+
import org.jboss.shrinkwrap.api.ShrinkWrap;
8+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
9+
import org.junit.jupiter.api.Test;
10+
import org.junit.jupiter.api.extension.RegisterExtension;
11+
12+
import dev.langchain4j.agentic.declarative.ActivationCondition;
13+
import dev.langchain4j.agentic.declarative.ConditionalAgent;
14+
import dev.langchain4j.agentic.declarative.SubAgent;
15+
import dev.langchain4j.agentic.scope.AgenticScope;
16+
import dev.langchain4j.service.IllegalConfigurationException;
17+
import dev.langchain4j.service.V;
18+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
19+
import io.quarkus.test.QuarkusUnitTest;
20+
21+
public class NonBooleanReturnTypeActivationConditionTest {
22+
23+
@RegisterExtension
24+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
25+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, ExpertsAgent.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("boolean"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface ExpertsAgent {
36+
37+
@ConditionalAgent(outputName = "response", subAgents = {
38+
@SubAgent(type = Agents.MedicalExpert.class, outputName = "response"),
39+
@SubAgent(type = Agents.TechnicalExpert.class, outputName = "response"),
40+
@SubAgent(type = Agents.LegalExpert.class, outputName = "response")
41+
})
42+
String askExpert(@V("request") String request);
43+
44+
@ActivationCondition(Agents.MedicalExpert.class)
45+
static boolean activateMedical(@V("category") Agents.RequestCategory category) {
46+
return category == Agents.RequestCategory.MEDICAL;
47+
}
48+
49+
@ActivationCondition(Agents.LegalExpert.class)
50+
static int activateLegal(AgenticScope agenticScope) {
51+
return 1;
52+
}
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.ExitCondition;
12+
import dev.langchain4j.agentic.declarative.LoopAgent;
13+
import dev.langchain4j.agentic.declarative.LoopCounter;
14+
import dev.langchain4j.agentic.declarative.SubAgent;
15+
import dev.langchain4j.service.IllegalConfigurationException;
16+
import dev.langchain4j.service.V;
17+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
public class NonBooleanReturnTypeExitConditionTest {
21+
22+
@RegisterExtension
23+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
24+
.setArchiveProducer(
25+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithCounter.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("boolean"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface StyleReviewLoopAgentWithCounter {
36+
37+
@LoopAgent(description = "Review the given story to ensure it aligns with the specified style", outputName = "story", maxIterations = 5, subAgents = {
38+
@SubAgent(type = Agents.StyleScorer.class, outputName = "score"),
39+
@SubAgent(type = Agents.StyleEditor.class, outputName = "story")
40+
})
41+
String write(@V("story") String story);
42+
43+
@ExitCondition(testExitAtLoopEnd = true)
44+
static Object exit(@V("score") double score, @LoopCounter int loopCounter) {
45+
return score >= 0.8;
46+
}
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.Agent;
12+
import dev.langchain4j.agentic.declarative.ChatMemoryProviderSupplier;
13+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
14+
import dev.langchain4j.service.IllegalConfigurationException;
15+
import dev.langchain4j.service.MemoryId;
16+
import dev.langchain4j.service.UserMessage;
17+
import dev.langchain4j.service.V;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
public class NonChatMemoryReturnTypeChatMemoryProviderSupplierTest {
21+
22+
@RegisterExtension
23+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
24+
.setArchiveProducer(
25+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(LegalExpertWithMemory.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("ChatMemory"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface LegalExpertWithMemory {
36+
37+
@UserMessage("""
38+
You are a legal expert.
39+
Analyze the following user request under a legal point of view and provide the best possible answer.
40+
The user request is {{request}}.
41+
""")
42+
@Agent("A legal expert")
43+
String legal(@MemoryId String memoryId, @V("request") String request);
44+
45+
@ChatMemoryProviderSupplier
46+
static Object chatMemory(Object memoryId) {
47+
return MessageWindowChatMemory.withMaxMessages(10);
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)