Skip to content

Commit 4fdab1d

Browse files
authored
Merge pull request #1157 from quarkiverse/weather-agent
Allow Rest Client and AI Service to be used as tools
2 parents 742b816 + 989ae31 commit 4fdab1d

File tree

27 files changed

+796
-26
lines changed

27 files changed

+796
-26
lines changed

core/deployment/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@
4646
<optional>true</optional> <!-- conditional dependency -->
4747
</dependency>
4848

49+
<dependency>
50+
<groupId>org.eclipse.microprofile.rest.client</groupId>
51+
<artifactId>microprofile-rest-client-api</artifactId>
52+
</dependency>
53+
4954
<dependency>
5055
<groupId>io.quarkus</groupId>
5156
<artifactId>quarkus-vertx-http-dev-ui-tests</artifactId>

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@
4343
import jakarta.annotation.PreDestroy;
4444
import jakarta.enterprise.context.Dependent;
4545
import jakarta.enterprise.inject.spi.DeploymentException;
46+
import jakarta.enterprise.util.AnnotationLiteral;
4647
import jakarta.inject.Inject;
4748

49+
import org.eclipse.microprofile.rest.client.inject.RestClient;
4850
import org.jboss.jandex.AnnotationInstance;
4951
import org.jboss.jandex.AnnotationTarget;
5052
import org.jboss.jandex.AnnotationValue;
@@ -77,6 +79,7 @@
7779
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
7880
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
7981
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
82+
import io.quarkiverse.langchain4j.deployment.items.ToolQualifierProvider;
8083
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
8184
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
8285
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
@@ -262,11 +265,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
262265
chatModelNames.add(chatModelName);
263266
}
264267

265-
List<DotName> toolDotNames = Collections.emptyList();
268+
List<ClassInfo> toolClassInfos = Collections.emptyList();
266269
AnnotationValue toolsInstance = instance.value("tools");
267270
if (toolsInstance != null) {
268-
toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
269-
.collect(Collectors.toList());
271+
toolClassInfos = Arrays.stream(toolsInstance.asClassArray()).map(t -> {
272+
var ci = index.getClassByName(t.name());
273+
if (ci == null) {
274+
throw new IllegalArgumentException("Cannot find class " + t.name()
275+
+ " in index. Please make sure it's a valid CDI bean known to Quarkus");
276+
}
277+
return ci;
278+
})
279+
.toList();
270280
}
271281

272282
// the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
@@ -397,7 +407,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
397407
declarativeAiServiceClassInfo,
398408
chatLanguageModelSupplierClassDotName,
399409
streamingChatLanguageModelSupplierClassDotName,
400-
toolDotNames,
410+
toolClassInfos,
401411
chatMemoryProviderSupplierClassDotName,
402412
retrieverClassDotName,
403413
retrievalAugmentorSupplierClassName,
@@ -476,11 +486,27 @@ private boolean isImageOrImageResultResult(Type returnType) {
476486
return false;
477487
}
478488

489+
@BuildStep
490+
public void toolQualifiers(BuildProducer<ToolQualifierProvider.BuildItem> producer) {
491+
producer.produce(new ToolQualifierProvider.BuildItem(new ToolQualifierProvider() {
492+
@Override
493+
public boolean supports(ClassInfo classInfo) {
494+
return classInfo.hasAnnotation(DotNames.REGISTER_REST_CLIENT);
495+
}
496+
497+
@Override
498+
public AnnotationLiteral<?> qualifier(ClassInfo classInfo) {
499+
return new RestClient.RestClientLiteral();
500+
}
501+
}));
502+
}
503+
479504
@BuildStep
480505
@Record(ExecutionTime.STATIC_INIT)
481506
public void handleDeclarativeServices(AiServicesRecorder recorder,
482507
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
483508
List<SelectedChatModelProviderBuildItem> selectedChatModelProvider,
509+
List<ToolQualifierProvider.BuildItem> toolQualifierProviderItems,
484510
BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer,
485511
BuildProducer<UnremovableBeanBuildItem> unremovableProducer) {
486512

@@ -507,7 +533,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
507533
? bi.getStreamingChatLanguageModelSupplierClassDotName().toString()
508534
: null);
509535

510-
List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
536+
List<ToolQualifierProvider> toolQualifierProviders = toolQualifierProviderItems.stream().map(
537+
ToolQualifierProvider.BuildItem::getProvider).toList();
538+
Map<String, AnnotationLiteral<?>> toolToQualifierMap = new HashMap<>();
539+
for (ClassInfo ci : bi.getToolClassInfos()) {
540+
AnnotationLiteral<?> qualifier = null;
541+
for (ToolQualifierProvider provider : toolQualifierProviders) {
542+
if (provider.supports(ci)) {
543+
qualifier = provider.qualifier(ci);
544+
break;
545+
}
546+
}
547+
toolToQualifierMap.put(ci.name().toString(), qualifier);
548+
}
511549

512550
String toolProviderSupplierClassName = (bi.getToolProviderClassDotName() != null
513551
? bi.getToolProviderClassDotName().toString()
@@ -597,7 +635,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
597635
serviceClassName,
598636
chatLanguageModelSupplierClassName,
599637
streamingChatLanguageModelSupplierClassName,
600-
toolClassNames,
638+
toolToQualifierMap,
601639
toolProviderSupplierClassName,
602640
chatMemoryProviderSupplierClassName, retrieverClassName,
603641
retrievalAugmentorSupplierClassName,
@@ -639,12 +677,16 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
639677
needsChatModelBean = true;
640678
}
641679

642-
if (!toolClassNames.isEmpty()) {
643-
for (String toolClassName : toolClassNames) {
644-
DotName dotName = DotName.createSimple(toolClassName);
680+
for (var entry : toolToQualifierMap.entrySet()) {
681+
DotName dotName = DotName.createSimple(entry.getKey());
682+
AnnotationLiteral<?> qualifier = entry.getValue();
683+
if (qualifier == null) {
645684
configurator.addInjectionPoint(ClassType.create(dotName));
646-
allToolNames.add(dotName);
685+
} else {
686+
configurator.addInjectionPoint(ClassType.create(dotName),
687+
AnnotationInstance.builder(qualifier.annotationType()).build());
647688
}
689+
allToolNames.add(dotName);
648690
}
649691

650692
if (LangChain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
1616
private final ClassInfo serviceClassInfo;
1717
private final DotName chatLanguageModelSupplierClassDotName;
1818
private final DotName streamingChatLanguageModelSupplierClassDotName;
19-
private final List<DotName> toolDotNames;
19+
private final List<ClassInfo> toolClassInfos;
2020
private final DotName toolProviderClassDotName;
2121

2222
private final DotName chatMemoryProviderSupplierClassDotName;
@@ -37,7 +37,7 @@ public DeclarativeAiServiceBuildItem(
3737
ClassInfo serviceClassInfo,
3838
DotName chatLanguageModelSupplierClassDotName,
3939
DotName streamingChatLanguageModelSupplierClassDotName,
40-
List<DotName> toolDotNames,
40+
List<ClassInfo> toolClassInfos,
4141
DotName chatMemoryProviderSupplierClassDotName,
4242
DotName retrieverClassDotName,
4343
DotName retrievalAugmentorSupplierClassDotName,
@@ -55,7 +55,7 @@ public DeclarativeAiServiceBuildItem(
5555
this.serviceClassInfo = serviceClassInfo;
5656
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
5757
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
58-
this.toolDotNames = toolDotNames;
58+
this.toolClassInfos = toolClassInfos;
5959
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
6060
this.retrieverClassDotName = retrieverClassDotName;
6161
this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName;
@@ -84,8 +84,8 @@ public DotName getStreamingChatLanguageModelSupplierClassDotName() {
8484
return streamingChatLanguageModelSupplierClassDotName;
8585
}
8686

87-
public List<DotName> getToolDotNames() {
88-
return toolDotNames;
87+
public List<ClassInfo> getToolClassInfos() {
88+
return toolClassInfos;
8989
}
9090

9191
public DotName getChatMemoryProviderSupplierClassDotName() {

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import jakarta.enterprise.inject.Instance;
1010

11+
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
1112
import org.jboss.jandex.DotName;
1213

1314
import dev.langchain4j.agent.tool.Tool;
@@ -62,6 +63,8 @@ public class DotNames {
6263
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
6364
public static final DotName TOOL = DotName.createSimple(Tool.class);
6465

66+
public static final DotName REGISTER_REST_CLIENT = DotName.createSimple(RegisterRestClient.class);
67+
6568
public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class);
6669

6770
/**

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public class ToolProcessor {
9292
private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class);
9393
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
9494
Object.class);
95+
private static final ResultHandle[] EMPTY_RESULT_HANDLE_ARRAY = new ResultHandle[0];
9596

9697
private static final Logger log = Logger.getLogger(ToolProcessor.class);
9798

@@ -136,7 +137,19 @@ public void handleTools(
136137

137138
MethodInfo methodInfo = instance.target().asMethod();
138139
ClassInfo classInfo = methodInfo.declaringClass();
139-
if (classInfo.isInterface() || Modifier.isAbstract(classInfo.flags())) {
140+
boolean causeValidationError = false;
141+
if (classInfo.isInterface()) {
142+
143+
if (classInfo.hasAnnotation(LangChain4jDotNames.REGISTER_AI_SERVICES) || classInfo.hasAnnotation(
144+
DotNames.REGISTER_REST_CLIENT)) {
145+
// we allow tools on method of these interfaces because we know they will be beans
146+
} else {
147+
causeValidationError = true;
148+
}
149+
} else if (Modifier.isAbstract(classInfo.flags())) {
150+
causeValidationError = true;
151+
}
152+
if (causeValidationError) {
140153
validation.produce(
141154
new ValidationPhaseBuildItem.ValidationErrorBuildItem(new IllegalStateException(
142155
"@Tool is only supported on non-abstract classes, all other usages are ignored. Offending method is '"
@@ -409,16 +422,21 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu
409422
MethodDescriptor.ofMethod(implClassName, "invoke", Object.class, Object.class, Object[].class));
410423

411424
ResultHandle result;
425+
ResultHandle[] targetMethodHandles = EMPTY_RESULT_HANDLE_ARRAY;
412426
if (methodInfo.parametersCount() > 0) {
413427
List<ResultHandle> argumentHandles = new ArrayList<>(methodInfo.parametersCount());
414428
for (int i = 0; i < methodInfo.parametersCount(); i++) {
415429
argumentHandles.add(invokeMc.readArrayValue(invokeMc.getMethodParam(1), i));
416430
}
417-
ResultHandle[] targetMethodHandles = argumentHandles.toArray(new ResultHandle[0]);
418-
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
431+
targetMethodHandles = argumentHandles.toArray(EMPTY_RESULT_HANDLE_ARRAY);
432+
}
433+
434+
if (methodInfo.declaringClass().isInterface()) {
435+
result = invokeMc.invokeInterfaceMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
419436
targetMethodHandles);
420437
} else {
421-
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0));
438+
result = invokeMc.invokeVirtualMethod(MethodDescriptor.of(methodInfo), invokeMc.getMethodParam(0),
439+
targetMethodHandles);
422440
}
423441

424442
boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/devui/LangChain4jDevUIProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ private void addEmbeddingStorePage(CardPageBuildItem card) {
7878
private void addAiServicesPage(CardPageBuildItem card, List<DeclarativeAiServiceBuildItem> aiServices) {
7979
List<AiServiceInfo> infos = new ArrayList<>();
8080
for (DeclarativeAiServiceBuildItem aiService : aiServices) {
81-
List<String> tools = aiService.getToolDotNames().stream().map(dotName -> dotName.toString()).toList();
81+
List<String> tools = aiService.getToolClassInfos().stream().map(ci -> ci.name().toString()).toList();
8282
infos.add(new AiServiceInfo(aiService.getServiceClassInfo().name().toString(), tools));
8383
}
8484

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package io.quarkiverse.langchain4j.deployment.items;
2+
3+
import jakarta.enterprise.util.AnnotationLiteral;
4+
5+
import org.jboss.jandex.ClassInfo;
6+
7+
import io.quarkus.builder.item.MultiBuildItem;
8+
9+
/**
10+
* Used to determine if a class containing a tool should be used along with a CDI qualifier
11+
*/
12+
public interface ToolQualifierProvider {
13+
14+
boolean supports(ClassInfo classInfo);
15+
16+
AnnotationLiteral<?> qualifier(ClassInfo classInfo);
17+
18+
final class BuildItem extends MultiBuildItem {
19+
20+
private final ToolQualifierProvider provider;
21+
22+
public BuildItem(ToolQualifierProvider provider) {
23+
this.provider = provider;
24+
}
25+
26+
public ToolQualifierProvider getProvider() {
27+
return provider;
28+
}
29+
}
30+
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.function.Supplier;
1212

1313
import jakarta.enterprise.inject.Instance;
14+
import jakarta.enterprise.util.AnnotationLiteral;
1415
import jakarta.enterprise.util.TypeLiteral;
1516

1617
import dev.langchain4j.data.segment.TextSegment;
@@ -148,12 +149,21 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
148149
}
149150
}
150151

151-
List<String> toolsClasses = info.toolsClassNames();
152+
Map<String, AnnotationLiteral<?>> toolsClasses = info.toolsClassInfo();
152153
if ((toolsClasses != null) && !toolsClasses.isEmpty()) {
153154
List<Object> tools = new ArrayList<>(toolsClasses.size());
154-
for (String toolClass : toolsClasses) {
155-
Object tool = creationalContext.getInjectedReference(
156-
Thread.currentThread().getContextClassLoader().loadClass(toolClass));
155+
for (var entry : toolsClasses.entrySet()) {
156+
AnnotationLiteral<?> qualifier = entry.getValue();
157+
Object tool;
158+
if (qualifier != null) {
159+
tool = creationalContext.getInjectedReference(
160+
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()),
161+
qualifier);
162+
} else {
163+
tool = creationalContext.getInjectedReference(
164+
Thread.currentThread().getContextClassLoader().loadClass(entry.getKey()));
165+
}
166+
157167
tools.add(tool);
158168
}
159169
quarkusAiServices.tools(tools);

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package io.quarkiverse.langchain4j.runtime.aiservice;
22

3-
import java.util.List;
3+
import java.util.Map;
4+
5+
import jakarta.enterprise.util.AnnotationLiteral;
46

57
public record DeclarativeAiServiceCreateInfo(
68
String serviceClassName,
79
String languageModelSupplierClassName,
810
String streamingChatLanguageModelSupplierClassName,
9-
List<String> toolsClassNames,
11+
Map<String, AnnotationLiteral<?>> toolsClassInfo,
1012
String toolProviderSupplier,
1113
String chatMemoryProviderSupplierClassName,
1214
String retrieverClassName,

model-providers/openai/openai-vanilla/deployment/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@
4646
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
4747
<scope>test</scope>
4848
</dependency>
49+
<dependency>
50+
<groupId>io.quarkus</groupId>
51+
<artifactId>quarkus-rest</artifactId>
52+
<scope>test</scope>
53+
</dependency>
4954
<dependency>
5055
<groupId>io.smallrye.certs</groupId>
5156
<artifactId>smallrye-certificate-generator-junit5</artifactId>

0 commit comments

Comments
 (0)