Skip to content

Commit fe4865d

Browse files
authored
Merge pull request #1166 from jmartisk/tool-provider-registration
Fix BeanIfExistsToolProviderSupplier, introduce NoToolProviderSupplier
2 parents 6194218 + d456262 commit fe4865d

File tree

5 files changed

+68
-17
lines changed

5 files changed

+68
-17
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,10 +349,14 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
349349
DotName toolProviderClassName = LangChain4jDotNames.BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER;
350350
AnnotationValue toolProviderValue = instance.value("toolProviderSupplier");
351351
if (toolProviderValue != null) {
352-
toolProviderClassName = toolProviderValue.asClass().name();
353-
validateSupplierAndRegisterForReflection(toolProviderClassName, index, reflectiveClassProducer);
354-
toolProviderInfos.add(new ToolProviderInfo(toolProviderClassName.toString(),
355-
declarativeAiServiceClassInfo.simpleName()));
352+
if (LangChain4jDotNames.NO_TOOL_PROVIDER_SUPPLIER.equals(toolProviderValue.asClass().name())) {
353+
toolProviderClassName = null;
354+
} else {
355+
toolProviderClassName = toolProviderValue.asClass().name();
356+
validateSupplierAndRegisterForReflection(toolProviderClassName, index, reflectiveClassProducer);
357+
toolProviderInfos.add(new ToolProviderInfo(toolProviderClassName.toString(),
358+
declarativeAiServiceClassInfo.simpleName()));
359+
}
356360
}
357361

358362
DotName imageModelSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_IMAGE_MODEL_SUPPLIER;
@@ -753,7 +757,11 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
753757
needsImageModelBean = true;
754758
}
755759

756-
if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
760+
if (RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
761+
.equals(toolProviderSupplierClassName)) {
762+
configurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
763+
new Type[] { ClassType.create(LangChain4jDotNames.TOOL_PROVIDER) }, null));
764+
} else if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
757765
.equals(toolProviderSupplierClassName) && toolProviderSupplierClassName != null) {
758766
DotName toolProvider = DotName.createSimple(toolProviderSupplierClassName);
759767
configurator.addInjectionPoint(ClassType.create(toolProvider));

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import dev.langchain4j.service.TokenStream;
2626
import dev.langchain4j.service.UserMessage;
2727
import dev.langchain4j.service.UserName;
28+
import dev.langchain4j.service.tool.ToolProvider;
2829
import dev.langchain4j.web.search.WebSearchEngine;
2930
import dev.langchain4j.web.search.WebSearchTool;
3031
import io.quarkiverse.langchain4j.CreatedAware;
@@ -104,6 +105,9 @@ public class LangChain4jDotNames {
104105
static final DotName BEAN_IF_EXISTS_TOOL_PROVIDER_SUPPLIER = DotName.createSimple(
105106
RegisterAiService.BeanIfExistsToolProviderSupplier.class);
106107

108+
static final DotName NO_TOOL_PROVIDER_SUPPLIER = DotName.createSimple(
109+
RegisterAiService.NoToolProviderSupplier.class);
110+
107111
static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple(
108112
QuarkusAiServiceContextQualifier.class);
109113

@@ -113,5 +117,5 @@ public class LangChain4jDotNames {
113117
static final DotName WEB_SEARCH_ENGINE = DotName.createSimple(WebSearchEngine.class);
114118
static final DotName IMAGE = DotName.createSimple(Image.class);
115119
static final DotName RESULT = DotName.createSimple(Result.class);
116-
static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProcessor.class);
120+
static final DotName TOOL_PROVIDER = DotName.createSimple(ToolProvider.class);
117121
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolProviderTest.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ class ToolProviderTest {
3939
MyServiceWithCustomToolProvider myServiceWithTools;
4040

4141
@Inject
42-
MyServiceWithDefaultToolProviderConfig myServiceWithoutTools;
42+
MyServiceWithDefaultToolProviderConfig myServiceWithIfExistsTools;
43+
44+
@Inject
45+
MyServiceWithNoToolProvider myServiceWithNoToolProvider;
4346

4447
@ApplicationScoped
4548
public static class MyCustomToolProviderSupplier implements Supplier<ToolProvider> {
@@ -108,9 +111,14 @@ interface MyServiceWithCustomToolProvider {
108111
String chat(@UserMessage String msg, @MemoryId Object id);
109112
}
110113

111-
@RegisterAiService(chatLanguageModelSupplier = BlockingChatLanguageModelSupplierTest.MyModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
114+
@RegisterAiService(chatLanguageModelSupplier = TestAiSupplier.class)
112115
interface MyServiceWithDefaultToolProviderConfig {
113-
String chat(String msg);
116+
String chat(@UserMessage String msg, @MemoryId Object id);
117+
}
118+
119+
@RegisterAiService(toolProviderSupplier = RegisterAiService.NoToolProviderSupplier.class, chatLanguageModelSupplier = TestAiSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
120+
interface MyServiceWithNoToolProvider {
121+
String chat(@UserMessage String msg, @MemoryId Object id);
114122
}
115123

116124
@RegisterExtension
@@ -126,10 +134,17 @@ void testCall() {
126134
assertEquals("0", answer);
127135
}
128136

137+
@Test
138+
@ActivateRequestContext
139+
void testCallDefaultTools() {
140+
String answer = myServiceWithIfExistsTools.chat("hello", 1);
141+
assertEquals("0", answer);
142+
}
143+
129144
@Test
130145
@ActivateRequestContext
131146
void testCallNoTools() {
132-
String answer = myServiceWithoutTools.chat("hello");
147+
String answer = myServiceWithNoToolProvider.chat("hello", 1);
133148
assertEquals("42", answer);
134149
}
135150
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,17 @@ public List<TextSegment> findRelevant(String text) {
209209
}
210210
}
211211

212+
/**
213+
* Marker that is used when the user does not want any tool provider
214+
*/
215+
final class NoToolProviderSupplier implements Supplier<ToolProvider> {
216+
217+
@Override
218+
public ToolProvider get() {
219+
throw new UnsupportedOperationException("should never be called");
220+
}
221+
}
222+
212223
/**
213224
* Marker that is used to tell Quarkus to use the {@link RetrievalAugmentor} that the user has configured as a CDI bean.
214225
* If no such bean exists, then no retrieval augmentor will be used.

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ public class AiServicesRecorder {
4343
private static final TypeLiteral<Instance<RetrievalAugmentor>> RETRIEVAL_AUGMENTOR_TYPE_LITERAL = new TypeLiteral<>() {
4444
};
4545

46+
private static final TypeLiteral<Instance<ToolProvider>> TOOL_PROVIDER_TYPE_LITERAL = new TypeLiteral<>() {
47+
};
48+
4649
// the key is the interface's class name
4750
private static final Map<String, AiServiceClassCreateInfo> metadata = new HashMap<>();
4851

@@ -169,13 +172,23 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
169172
quarkusAiServices.tools(tools);
170173
}
171174

172-
if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
173-
.equals(info.toolProviderSupplier())) {
174-
Class<?> toolProviderClass = Thread.currentThread().getContextClassLoader()
175-
.loadClass(info.toolProviderSupplier());
176-
Supplier<? extends ToolProvider> toolProvider = (Supplier<? extends ToolProvider>) creationalContext
177-
.getInjectedReference(toolProviderClass);
178-
quarkusAiServices.toolProvider(toolProvider.get());
175+
if (info.toolProviderSupplier() != null) {
176+
if (!RegisterAiService.BeanIfExistsToolProviderSupplier.class.getName()
177+
.equals(info.toolProviderSupplier())) {
178+
// specific provider
179+
Class<?> toolProviderClass = Thread.currentThread().getContextClassLoader()
180+
.loadClass(info.toolProviderSupplier());
181+
Supplier<? extends ToolProvider> toolProvider = (Supplier<? extends ToolProvider>) creationalContext
182+
.getInjectedReference(toolProviderClass);
183+
quarkusAiServices.toolProvider(toolProvider.get());
184+
} else {
185+
// if-exists provider
186+
Instance<ToolProvider> instance = creationalContext
187+
.getInjectedReference(TOOL_PROVIDER_TYPE_LITERAL);
188+
if (instance.isResolvable()) {
189+
quarkusAiServices.toolProvider(instance.get());
190+
}
191+
}
179192
}
180193

181194
if (info.chatMemoryProviderSupplierClassName() != null) {

0 commit comments

Comments
 (0)