Skip to content

Commit def13c4

Browse files
authored
Merge pull request #1837 from patriot1burke/immediate-resultObject
InvocationParameters and ToolExecution.resultObject() support
2 parents 9e21b4d + d173843 commit def13c4

File tree

11 files changed

+208
-68
lines changed

11 files changed

+208
-68
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
7070

7171
import dev.langchain4j.guardrail.OutputGuardrail;
72+
import dev.langchain4j.invocation.InvocationParameters;
7273
import dev.langchain4j.memory.ChatMemory;
7374
import dev.langchain4j.model.chat.request.json.JsonSchema;
7475
import dev.langchain4j.service.IllegalConfigurationException;
@@ -162,6 +163,7 @@ public class AiServicesProcessor {
162163
private static final DotName TOOLBOX = DotName.createSimple(ToolBox.class);
163164
public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed");
164165
public static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted");
166+
private static final DotName INVOCATION_PARAMETERS = DotName.createSimple(InvocationParameters.class);
165167
public static final String DEFAULT_DELIMITER = "\n";
166168
public static final Predicate<AnnotationInstance> IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target()
167169
.kind() == AnnotationTarget.Kind.METHOD_PARAMETER;
@@ -1846,6 +1848,9 @@ private List<TemplateParameterInfo> gatherTemplateParamInfo(MethodInfo method,
18461848
private boolean isParameterAllowedAsTemplateVariable(
18471849
MethodParameterInfo param, Collection<Predicate<AnnotationInstance>> allowedPredicates,
18481850
Collection<Predicate<AnnotationInstance>> ignoredPredicates) {
1851+
if (param.type().name().equals(INVOCATION_PARAMETERS)) {
1852+
return false;
1853+
}
18491854

18501855
Collection<MethodParameterAsTemplateVariableAllowance> allowances = param.annotations().stream().map(anno -> {
18511856

@@ -1974,41 +1979,56 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn
19741979
userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition);
19751980
}
19761981
} else {
1977-
int numOfMethodParamsUsedInSystemMessage = 0;
1982+
Set<String> templateParamNames = Collections.EMPTY_SET;
19781983
if (systemMessageInfo.isPresent() && systemMessageInfo.get().text().isPresent()) {
1979-
Set<String> templateParamNames = TemplateUtil.parts(systemMessageInfo.get().text().get()).stream()
1984+
templateParamNames = TemplateUtil.parts(systemMessageInfo.get().text().get()).stream()
19801985
.flatMap(l -> l.stream().map(
19811986
Expression.Part::getName))
19821987
.collect(Collectors.toSet());
1983-
for (MethodParameterInfo parameter : method.parameters()) {
1984-
if (templateParamNames.contains(parameter.name())) {
1985-
numOfMethodParamsUsedInSystemMessage++;
1986-
}
1987-
}
19881988
}
1989-
if (numOfMethodParamsUsedInSystemMessage != method.parametersCount()) {
1990-
if (method.parametersCount() == 0) {
1991-
throw illegalConfigurationForMethod("Method should have at least one argument", method);
1992-
}
1993-
if (method.parametersCount() == 1) {
1994-
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition,
1995-
imageParamPosition, audioParamPosition, pdfParamPosition);
1989+
int userMessageParamPosition = -1;
1990+
int undefinedParams = 0;
1991+
for (int i = 0; i < method.parametersCount(); i++) {
1992+
MethodParameterInfo parameter = method.parameters().get(i);
1993+
if (templateParamNames.contains(parameter.name())) {
1994+
continue;
1995+
} else if (userNameParamPosition.isPresent() && i == userNameParamPosition.get()) {
1996+
continue;
1997+
} else if (imageParamPosition.isPresent() && i == imageParamPosition.get()) {
1998+
continue;
1999+
} else if (audioParamPosition.isPresent() && i == audioParamPosition.get()) {
2000+
continue;
2001+
} else if (pdfParamPosition.isPresent() && i == pdfParamPosition.get()) {
2002+
continue;
2003+
} else if (parameter.type().name().equals(INVOCATION_PARAMETERS)) {
2004+
continue;
2005+
} else if (parameter.hasAnnotation(LangChain4jDotNames.MEMORY_ID)) {
2006+
continue;
19962007
}
2008+
undefinedParams++;
2009+
if (undefinedParams > 1) {
2010+
if (fallbackToDummyUserMesage.test(method)) {
2011+
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
2012+
AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(),
2013+
Optional.empty(),
2014+
Optional.empty(), Optional.empty());
2015+
}
19972016

1998-
if (fallbackToDummyUserMesage.test(method)) {
1999-
return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate(
2000-
AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(),
2001-
Optional.empty(),
2002-
Optional.empty(), Optional.empty());
2017+
throw illegalConfigurationForMethod(
2018+
"For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
2019+
method);
20032020
}
2004-
2005-
throw illegalConfigurationForMethod(
2006-
"For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId",
2007-
method);
2008-
} else {
2009-
// all method parameters are present in the system message, so there is no user message
2021+
userMessageParamPosition = i;
2022+
}
2023+
if (userMessageParamPosition == -1) {
2024+
// There is no user message
20102025
return new AiServiceMethodCreateInfo.UserMessageInfo(Optional.empty(), Optional.empty(), Optional.empty(),
20112026
Optional.empty(), Optional.empty(), Optional.empty());
2027+
} else {
2028+
return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(userMessageParamPosition,
2029+
userNameParamPosition,
2030+
imageParamPosition, audioParamPosition, pdfParamPosition);
2031+
20122032
}
20132033
}
20142034
}

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import dev.langchain4j.agent.tool.Tool;
4545
import dev.langchain4j.agent.tool.ToolMemoryId;
4646
import dev.langchain4j.agent.tool.ToolSpecification;
47+
import dev.langchain4j.invocation.InvocationParameters;
4748
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
4849
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
4950
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
@@ -90,11 +91,12 @@ public class ToolProcessor {
9091
private static final DotName TOOL = DotName.createSimple(Tool.class);
9192
private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class);
9293
private static final DotName JSON_IGNORE = DotName.createSimple(JsonIgnore.class);
94+
private static final DotName INVOCATION_PARAMETERS = DotName.createSimple(InvocationParameters.class);
9395

9496
private static final DotName P = DotName.createSimple(dev.langchain4j.agent.tool.P.class);
9597
private static final DotName DESCRIPTION = DotName.createSimple(Description.class);
9698
private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor
97-
.ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class);
99+
.ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class, Integer.class);
98100
private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class);
99101
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
100102
Object.class);
@@ -254,11 +256,16 @@ public void handleTools(
254256
var required = new ArrayList<String>(toolMethod.parametersCount());
255257

256258
MethodParameterInfo memoryIdParameter = null;
259+
MethodParameterInfo invocationParamsParameter = null;
257260
for (MethodParameterInfo parameter : toolMethod.parameters()) {
258261
if (parameter.hasAnnotation(TOOL_MEMORY_ID)) {
259262
memoryIdParameter = parameter;
260263
continue;
261264
}
265+
if (parameter.type().name().equals(INVOCATION_PARAMETERS)) {
266+
invocationParamsParameter = parameter;
267+
continue;
268+
}
262269

263270
var pInstance = parameter.annotation(P);
264271
var jsonSchemaElement = toJsonSchemaElement(parameter, index);
@@ -282,7 +289,9 @@ public void handleTools(
282289
String methodSignature = createUniqueSignature(toolMethod);
283290

284291
String invokerClassName = generateInvoker(toolMethod, classOutput, nameToParamPosition,
285-
memoryIdParameter != null ? memoryIdParameter.position() : null, methodSignature);
292+
memoryIdParameter != null ? memoryIdParameter.position() : null,
293+
invocationParamsParameter != null ? invocationParamsParameter.position() : null,
294+
methodSignature);
286295
generatedInvokerClasses.add(invokerClassName);
287296
String argumentMapperClassName = generateArgumentMapper(toolMethod, classOutput,
288297
methodSignature);
@@ -452,7 +461,8 @@ private String getToolDescription(AnnotationValue descriptionValue) {
452461
}
453462

454463
private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOutput,
455-
Map<String, Integer> nameToParamPosition, Short memoryIdParamPosition, String methodSignature) {
464+
Map<String, Integer> nameToParamPosition, Short memoryIdParamPosition, Short invocationParamsParamPosition,
465+
String methodSignature) {
456466
String implClassName = methodInfo.declaringClass().name() + "$$QuarkusInvoker$" + methodInfo.name() + "_"
457467
+ HashUtil.sha1(methodSignature);
458468
try (ClassCreator classCreator = ClassCreator.builder()
@@ -484,7 +494,7 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu
484494

485495
boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;
486496
if (toolReturnsVoid) {
487-
invokeMc.returnValue(invokeMc.load("Success"));
497+
invokeMc.returnValue(invokeMc.loadNull());
488498
} else {
489499
invokeMc.returnValue(result);
490500
}
@@ -503,6 +513,9 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu
503513
methodMetadataMc.load(toolReturnsVoid),
504514
nameToParamPositionHandle,
505515
memoryIdParamPosition != null ? methodMetadataMc.load(Integer.valueOf(memoryIdParamPosition))
516+
: methodMetadataMc.loadNull(),
517+
invocationParamsParamPosition != null
518+
? methodMetadataMc.load(Integer.valueOf(invocationParamsParamPosition))
506519
: methodMetadataMc.loadNull());
507520
methodMetadataMc.returnValue(resultHandle);
508521
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,21 @@ void testImmediate() {
270270
assertThat(r.finishReason()).isEqualTo(FinishReason.TOOL_EXECUTION);
271271
assertThat(r.toolExecutions()).hasSize(1);
272272
assertThat(r.toolExecutions().get(0).result()).contains("hiImmediate");
273+
assertThat(r.toolExecutions().get(0).resultObject()).isInstanceOf(String.class);
274+
assertThat((String) r.toolExecutions().get(0).resultObject()).contains("hiImmediate");
275+
}
276+
277+
@Test
278+
@ActivateRequestContext
279+
void testImmediateVoid() {
280+
// This tests @Tool(returnBehavior = ReturnBehavior.IMMEDIATE)
281+
String uuid = UUID.randomUUID().toString();
282+
Result<String> r = aiService.helloResult("abc", "hiImmediateVoid - " + uuid);
283+
assertNull(r.content());
284+
assertThat(r.finishReason()).isEqualTo(FinishReason.TOOL_EXECUTION);
285+
assertThat(r.toolExecutions()).hasSize(1);
286+
assertThat(r.toolExecutions().get(0).result()).contains("Success");
287+
assertThat(r.toolExecutions().get(0).resultObject()).isNull();
273288
}
274289

275290
@Test
@@ -303,6 +318,10 @@ public String hiImmediate(String m) {
303318
return "hiImmediate";
304319
}
305320

321+
@Tool(returnBehavior = ReturnBehavior.IMMEDIATE)
322+
public void hiImmediateVoid(String m) {
323+
}
324+
306325
@Tool
307326
public String hi(String m) {
308327
return m + " " + Thread.currentThread();

0 commit comments

Comments
 (0)