Skip to content

Commit ab827c1

Browse files
committed
Introduce compile time safety to the agentic module
Quarkus will now check to make sure that declarative API conforms to all the requirements
1 parent 7463abc commit ab827c1

File tree

36 files changed

+1867
-16
lines changed

36 files changed

+1867
-16
lines changed

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticLangChain4jDotNames.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,32 @@
55
import org.jboss.jandex.DotName;
66

77
import dev.langchain4j.agentic.Agent;
8+
import dev.langchain4j.agentic.agent.AgentRequest;
9+
import dev.langchain4j.agentic.agent.AgentResponse;
10+
import dev.langchain4j.agentic.agent.ErrorContext;
11+
import dev.langchain4j.agentic.agent.ErrorRecoveryResult;
12+
import dev.langchain4j.agentic.declarative.ActivationCondition;
13+
import dev.langchain4j.agentic.declarative.AfterAgentInvocation;
14+
import dev.langchain4j.agentic.declarative.BeforeAgentInvocation;
15+
import dev.langchain4j.agentic.declarative.ChatMemoryProviderSupplier;
16+
import dev.langchain4j.agentic.declarative.ChatMemorySupplier;
817
import dev.langchain4j.agentic.declarative.ChatModelSupplier;
918
import dev.langchain4j.agentic.declarative.ConditionalAgent;
19+
import dev.langchain4j.agentic.declarative.ContentRetrieverSupplier;
20+
import dev.langchain4j.agentic.declarative.ErrorHandler;
21+
import dev.langchain4j.agentic.declarative.ExitCondition;
22+
import dev.langchain4j.agentic.declarative.HumanInTheLoop;
23+
import dev.langchain4j.agentic.declarative.HumanInTheLoopResponseSupplier;
1024
import dev.langchain4j.agentic.declarative.LoopAgent;
25+
import dev.langchain4j.agentic.declarative.Output;
1126
import dev.langchain4j.agentic.declarative.ParallelAgent;
27+
import dev.langchain4j.agentic.declarative.ParallelExecutor;
28+
import dev.langchain4j.agentic.declarative.RetrievalAugmentorSupplier;
1229
import dev.langchain4j.agentic.declarative.SequenceAgent;
1330
import dev.langchain4j.agentic.declarative.SubAgent;
1431
import dev.langchain4j.agentic.declarative.SupervisorAgent;
32+
import dev.langchain4j.agentic.declarative.ToolProviderSupplier;
33+
import dev.langchain4j.agentic.declarative.ToolsSupplier;
1534
import dev.langchain4j.agentic.scope.AgenticScope;
1635
import dev.langchain4j.agentic.scope.ResultWithAgenticScope;
1736

@@ -32,6 +51,40 @@ public final class AgenticLangChain4jDotNames {
3251
public static final DotName AGENTIC_SCOPE = DotName.createSimple(AgenticScope.class);
3352
public static final DotName RESULT_WITH_AGENTIC_SCOPE = DotName.createSimple(ResultWithAgenticScope.class);
3453

54+
public static final DotName ACTIVATION_CONDITION = DotName.createSimple(ActivationCondition.class.getName());
55+
public static final DotName BEFORE_AGENT_INVOCATION = DotName.createSimple(BeforeAgentInvocation.class.getName());
56+
public static final DotName AFTER_AGENT_INVOCATION = DotName.createSimple(AfterAgentInvocation.class.getName());
57+
public static final DotName AGENT_REQUEST = DotName.createSimple(AgentRequest.class.getName());
58+
public static final DotName AGENT_RESPONSE = DotName.createSimple(AgentResponse.class.getName());
59+
public static final DotName CHAT_MEMORY_PROVIDER_SUPPLIER = DotName
60+
.createSimple(ChatMemoryProviderSupplier.class.getName());
61+
public static final DotName CHAT_MEMORY_SUPPLIER = DotName
62+
.createSimple(ChatMemorySupplier.class.getName());
63+
public static final DotName CONTENT_RETRIEVER_SUPPLIER = DotName
64+
.createSimple(ContentRetrieverSupplier.class.getName());
65+
public static final DotName ERROR_HANDLER = DotName
66+
.createSimple(ErrorHandler.class.getName());
67+
public static final DotName ERROR_CONTEXT = DotName
68+
.createSimple(ErrorContext.class.getName());
69+
public static final DotName ERROR_RECOVERY_RESULT = DotName
70+
.createSimple(ErrorRecoveryResult.class.getName());
71+
public static final DotName EXIT_CONDITION = DotName
72+
.createSimple(ExitCondition.class.getName());
73+
public static final DotName HUMAN_IN_THE_LOOP = DotName
74+
.createSimple(HumanInTheLoop.class.getName());
75+
public static final DotName HUMAN_IN_THE_LOOP_RESPONSE_SUPPLIER = DotName
76+
.createSimple(HumanInTheLoopResponseSupplier.class.getName());
77+
public static final DotName OUTPUT = DotName
78+
.createSimple(Output.class.getName());
79+
public static final DotName PARALLEL_EXECUTOR = DotName
80+
.createSimple(ParallelExecutor.class.getName());
81+
public static final DotName RETRIEVAL_AUGMENTER_SUPPLIER = DotName
82+
.createSimple(RetrievalAugmentorSupplier.class.getName());
83+
public static final DotName TOOL_PROVIDER_SUPPLIER = DotName
84+
.createSimple(ToolProviderSupplier.class.getName());
85+
public static final DotName TOOL_SUPPLIER = DotName
86+
.createSimple(ToolsSupplier.class.getName());
87+
3588
private AgenticLangChain4jDotNames() {
3689
}
3790
}

agentic/deployment/src/main/java/io/quarkiverse/langchain4j/agentic/deployment/AgenticProcessor.java

Lines changed: 287 additions & 4 deletions
Large diffs are not rendered by default.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment;
2+
3+
import java.lang.reflect.Modifier;
4+
import java.util.List;
5+
import java.util.Set;
6+
import java.util.stream.Collectors;
7+
8+
import org.jboss.jandex.DotName;
9+
import org.jboss.jandex.MethodInfo;
10+
11+
import dev.langchain4j.service.IllegalConfigurationException;
12+
13+
class ValidationUtil {
14+
15+
static void validateStaticMethod(MethodInfo method, DotName annotationName) {
16+
if (!Modifier.isStatic(method.flags())) {
17+
throw new IllegalConfigurationException(
18+
String.format("Methods annotated with '%s' must be static. Offending method is '%s' of class '%s'",
19+
annotationName, method.name(), method.declaringClass().name().toString()));
20+
}
21+
}
22+
23+
static void validateAllowedReturnTypes(MethodInfo method, Set<DotName> allowedReturnTypes, DotName annotationName) {
24+
if (!allowedReturnTypes.contains(method.returnType().name())) {
25+
throw new IllegalConfigurationException(
26+
String.format(
27+
"Methods annotated with '%s' can only use the following return types: '%s'. Offending method is '%s' of class '%s'",
28+
annotationName,
29+
allowedReturnTypes.stream().map(DotName::withoutPackagePrefix).collect(Collectors.joining(",")),
30+
method.name(),
31+
method.declaringClass().name().toString()));
32+
}
33+
}
34+
35+
static void validateRequiredParameterTypes(MethodInfo method, List<DotName> requiredParameterTypes,
36+
DotName annotationName) {
37+
if (method.parameters().size() != requiredParameterTypes.size()) {
38+
throw new IllegalConfigurationException(
39+
String.format(
40+
"Methods annotated with '%s' must use the following parameter types: '%s'. Offending method is '%s' of class '%s'",
41+
annotationName,
42+
requiredParameterTypes.stream().map(DotName::withoutPackagePrefix)
43+
.collect(Collectors.joining(",")),
44+
method.name(),
45+
method.declaringClass().name().toString()));
46+
}
47+
48+
for (int i = 0; i < requiredParameterTypes.size(); i++) {
49+
DotName parameterTypeName = method.parameters().get(i).type().name();
50+
if (!parameterTypeName.equals(requiredParameterTypes.get(i))) {
51+
throw new IllegalConfigurationException(
52+
String.format(
53+
"Methods annotated with '%s' must use the following parameter types: '%s'. Offending method is '%s' of class '%s'",
54+
annotationName,
55+
requiredParameterTypes.stream().map(DotName::withoutPackagePrefix)
56+
.collect(Collectors.joining(",")),
57+
method.name(),
58+
method.declaringClass().name().toString()));
59+
}
60+
}
61+
}
62+
63+
static void validateNoMethodParameters(MethodInfo method, DotName annotationName) {
64+
if (!method.parameters().isEmpty()) {
65+
throw new IllegalConfigurationException(
66+
String.format(
67+
"Methods annotated with '%s' cannot have any method parameters. Offending method is '%s' of class '%s'",
68+
annotationName,
69+
method.name(),
70+
method.declaringClass().name().toString()));
71+
}
72+
}
73+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.BeforeAgentInvocation;
12+
import dev.langchain4j.service.IllegalConfigurationException;
13+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
14+
import io.quarkus.test.QuarkusUnitTest;
15+
16+
public class NonAgentRequestParameterTypeBeforeAgentInvocationTest {
17+
18+
@RegisterExtension
19+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
20+
.setArchiveProducer(
21+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithListener.class))
22+
.assertException(
23+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
24+
.hasMessageContaining("AgentRequest"));
25+
26+
@Test
27+
public void test() {
28+
fail("should never be called");
29+
}
30+
31+
public interface StyleReviewLoopAgentWithListener extends Agents.StyleReviewLoopAgent {
32+
33+
@BeforeAgentInvocation
34+
static void beforeAgentInvocation(Void request) {
35+
36+
}
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.AfterAgentInvocation;
12+
import dev.langchain4j.service.IllegalConfigurationException;
13+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
14+
import io.quarkus.test.QuarkusUnitTest;
15+
16+
public class NonAgentResponseParameterTypeAfterAgentInvocationTest {
17+
18+
@RegisterExtension
19+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
20+
.setArchiveProducer(
21+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithListener.class))
22+
.assertException(
23+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
24+
.hasMessageContaining("AgentResponse"));
25+
26+
@Test
27+
public void test() {
28+
fail("should never be called");
29+
}
30+
31+
public interface StyleReviewLoopAgentWithListener extends Agents.StyleReviewLoopAgent {
32+
33+
@AfterAgentInvocation
34+
static void afterAgentInvocation(Void response) {
35+
36+
}
37+
}
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.junit.jupiter.api.Assertions.fail;
5+
6+
import org.assertj.core.api.Assertions;
7+
import org.jboss.shrinkwrap.api.ShrinkWrap;
8+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
9+
import org.junit.jupiter.api.Test;
10+
import org.junit.jupiter.api.extension.RegisterExtension;
11+
12+
import dev.langchain4j.agentic.declarative.ActivationCondition;
13+
import dev.langchain4j.agentic.declarative.ConditionalAgent;
14+
import dev.langchain4j.agentic.declarative.SubAgent;
15+
import dev.langchain4j.agentic.scope.AgenticScope;
16+
import dev.langchain4j.service.IllegalConfigurationException;
17+
import dev.langchain4j.service.V;
18+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
19+
import io.quarkus.test.QuarkusUnitTest;
20+
21+
public class NonBooleanReturnTypeActivationConditionTest {
22+
23+
@RegisterExtension
24+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
25+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, ExpertsAgent.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("boolean"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface ExpertsAgent {
36+
37+
@ConditionalAgent(outputName = "response", subAgents = {
38+
@SubAgent(type = Agents.MedicalExpert.class, outputName = "response"),
39+
@SubAgent(type = Agents.TechnicalExpert.class, outputName = "response"),
40+
@SubAgent(type = Agents.LegalExpert.class, outputName = "response")
41+
})
42+
String askExpert(@V("request") String request);
43+
44+
@ActivationCondition(Agents.MedicalExpert.class)
45+
static boolean activateMedical(@V("category") Agents.RequestCategory category) {
46+
return category == Agents.RequestCategory.MEDICAL;
47+
}
48+
49+
@ActivationCondition(Agents.LegalExpert.class)
50+
static int activateLegal(AgenticScope agenticScope) {
51+
return 1;
52+
}
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.declarative.ExitCondition;
12+
import dev.langchain4j.agentic.declarative.LoopAgent;
13+
import dev.langchain4j.agentic.declarative.LoopCounter;
14+
import dev.langchain4j.agentic.declarative.SubAgent;
15+
import dev.langchain4j.service.IllegalConfigurationException;
16+
import dev.langchain4j.service.V;
17+
import io.quarkiverse.langchain4j.agentic.deployment.Agents;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
public class NonBooleanReturnTypeExitConditionTest {
21+
22+
@RegisterExtension
23+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
24+
.setArchiveProducer(
25+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(Agents.class, StyleReviewLoopAgentWithCounter.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("boolean"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface StyleReviewLoopAgentWithCounter {
36+
37+
@LoopAgent(description = "Review the given story to ensure it aligns with the specified style", outputName = "story", maxIterations = 5, subAgents = {
38+
@SubAgent(type = Agents.StyleScorer.class, outputName = "score"),
39+
@SubAgent(type = Agents.StyleEditor.class, outputName = "story")
40+
})
41+
String write(@V("story") String story);
42+
43+
@ExitCondition(testExitAtLoopEnd = true)
44+
static Object exit(@V("score") double score, @LoopCounter int loopCounter) {
45+
return score >= 0.8;
46+
}
47+
}
48+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.quarkiverse.langchain4j.agentic.deployment.validation;
2+
3+
import static org.junit.jupiter.api.Assertions.fail;
4+
5+
import org.assertj.core.api.Assertions;
6+
import org.jboss.shrinkwrap.api.ShrinkWrap;
7+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
8+
import org.junit.jupiter.api.Test;
9+
import org.junit.jupiter.api.extension.RegisterExtension;
10+
11+
import dev.langchain4j.agentic.Agent;
12+
import dev.langchain4j.agentic.declarative.ChatMemoryProviderSupplier;
13+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
14+
import dev.langchain4j.service.IllegalConfigurationException;
15+
import dev.langchain4j.service.MemoryId;
16+
import dev.langchain4j.service.UserMessage;
17+
import dev.langchain4j.service.V;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
public class NonChatMemoryReturnTypeChatMemoryProviderSupplierTest {
21+
22+
@RegisterExtension
23+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
24+
.setArchiveProducer(
25+
() -> ShrinkWrap.create(JavaArchive.class).addClasses(LegalExpertWithMemory.class))
26+
.assertException(
27+
throwable -> Assertions.assertThat(throwable).isInstanceOf(IllegalConfigurationException.class)
28+
.hasMessageContaining("ChatMemory"));
29+
30+
@Test
31+
public void test() {
32+
fail("should never be called");
33+
}
34+
35+
public interface LegalExpertWithMemory {
36+
37+
@UserMessage("""
38+
You are a legal expert.
39+
Analyze the following user request under a legal point of view and provide the best possible answer.
40+
The user request is {{request}}.
41+
""")
42+
@Agent("A legal expert")
43+
String legal(@MemoryId String memoryId, @V("request") String request);
44+
45+
@ChatMemoryProviderSupplier
46+
static Object chatMemory(Object memoryId) {
47+
return MessageWindowChatMemory.withMaxMessages(10);
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)