Skip to content

Commit cb3bbad

Browse files
authored
Merge pull request #1853 from ejstuart/fix/toolprovider-with-streaming
Add check for tool providers when streaming
2 parents 853faed + c6077a9 commit cb3bbad

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,24 +1078,33 @@ public void markUsedResponseAugmenterUnremovable(List<AiServicesMethodBuildItem>
10781078
*
10791079
* @param method the AI method
10801080
* @param tools the tools
1081+
* @param toolProviderClassDotName the tool provider class name (if configured)
10811082
*/
10821083
public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(
10831084
MethodInfo method,
10841085
List<String> associatedTools,
1085-
List<ToolMethodBuildItem> tools) {
1086+
List<ToolMethodBuildItem> tools,
1087+
DotName toolProviderClassDotName) {
10861088
boolean reactive = method.returnType().name().equals(DotNames.UNI)
10871089
|| method.returnType().name().equals(DotNames.COMPLETION_STAGE)
10881090
|| method.returnType().name().equals(DotNames.MULTI);
10891091

10901092
boolean requireSwitchToWorkerThread = false;
10911093

1092-
if (associatedTools.isEmpty()) {
1093-
// No tools, no need to dispatch
1094+
if (!reactive) {
1095+
// We are already on a thread we can block.
10941096
return false;
10951097
}
10961098

1097-
if (!reactive) {
1098-
// We are already on a thread we can block.
1099+
// If a ToolProvider is configured for a reactive method, assume it may provide blocking tools at runtime
1100+
if (toolProviderClassDotName != null
1101+
&& !LangChain4jDotNames.NO_TOOL_PROVIDER_SUPPLIER.equals(toolProviderClassDotName)) {
1102+
// Be conservative: assume ToolProvider may supply blocking tools at runtime
1103+
return true;
1104+
}
1105+
1106+
if (associatedTools.isEmpty()) {
1107+
// No tools, no need to dispatch
10991108
return false;
11001109
}
11011110

@@ -1763,15 +1772,21 @@ private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
17631772
private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
17641773
Collection<String> methodToolClassNames) {
17651774
List<String> allTools = new ArrayList<>(methodToolClassNames);
1775+
DotName toolProviderClassDotName = null;
17661776
// We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing.
17671777
AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES);
17681778
if (annotation != null) {
17691779
AnnotationValue value = annotation.value("tools");
17701780
if (value != null) {
17711781
allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList());
17721782
}
1783+
// Extract toolProviderSupplier from annotation
1784+
AnnotationValue toolProviderValue = annotation.value("toolProviderSupplier");
1785+
if (toolProviderValue != null) {
1786+
toolProviderClassDotName = toolProviderValue.asClass().name();
1787+
}
17731788
}
1774-
return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools);
1789+
return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools, toolProviderClassDotName);
17751790
}
17761791

17771792
private void validateReturnType(MethodInfo method) {

0 commit comments

Comments
 (0)