|
69 | 69 | import com.fasterxml.jackson.databind.PropertyNamingStrategies; |
70 | 70 |
|
71 | 71 | import dev.langchain4j.guardrail.OutputGuardrail; |
| 72 | +import dev.langchain4j.invocation.InvocationParameters; |
72 | 73 | import dev.langchain4j.memory.ChatMemory; |
73 | 74 | import dev.langchain4j.model.chat.request.json.JsonSchema; |
74 | 75 | import dev.langchain4j.service.IllegalConfigurationException; |
@@ -162,6 +163,7 @@ public class AiServicesProcessor { |
162 | 163 | private static final DotName TOOLBOX = DotName.createSimple(ToolBox.class); |
163 | 164 | public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed"); |
164 | 165 | public static final DotName MICROMETER_COUNTED = DotName.createSimple("io.micrometer.core.annotation.Counted"); |
| 166 | + private static final DotName INVOCATION_PARAMETERS = DotName.createSimple(InvocationParameters.class); |
165 | 167 | public static final String DEFAULT_DELIMITER = "\n"; |
166 | 168 | public static final Predicate<AnnotationInstance> IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target() |
167 | 169 | .kind() == AnnotationTarget.Kind.METHOD_PARAMETER; |
@@ -1846,6 +1848,9 @@ private List<TemplateParameterInfo> gatherTemplateParamInfo(MethodInfo method, |
1846 | 1848 | private boolean isParameterAllowedAsTemplateVariable( |
1847 | 1849 | MethodParameterInfo param, Collection<Predicate<AnnotationInstance>> allowedPredicates, |
1848 | 1850 | Collection<Predicate<AnnotationInstance>> ignoredPredicates) { |
| 1851 | + if (param.type().name().equals(INVOCATION_PARAMETERS)) { |
| 1852 | + return false; |
| 1853 | + } |
1849 | 1854 |
|
1850 | 1855 | Collection<MethodParameterAsTemplateVariableAllowance> allowances = param.annotations().stream().map(anno -> { |
1851 | 1856 |
|
@@ -1974,41 +1979,56 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn |
1974 | 1979 | userNameParamPosition, imageParamPosition, audioParamPosition, pdfParamPosition); |
1975 | 1980 | } |
1976 | 1981 | } else { |
1977 | | - int numOfMethodParamsUsedInSystemMessage = 0; |
| 1982 | + Set<String> templateParamNames = Collections.EMPTY_SET; |
1978 | 1983 | if (systemMessageInfo.isPresent() && systemMessageInfo.get().text().isPresent()) { |
1979 | | - Set<String> templateParamNames = TemplateUtil.parts(systemMessageInfo.get().text().get()).stream() |
| 1984 | + templateParamNames = TemplateUtil.parts(systemMessageInfo.get().text().get()).stream() |
1980 | 1985 | .flatMap(l -> l.stream().map( |
1981 | 1986 | Expression.Part::getName)) |
1982 | 1987 | .collect(Collectors.toSet()); |
1983 | | - for (MethodParameterInfo parameter : method.parameters()) { |
1984 | | - if (templateParamNames.contains(parameter.name())) { |
1985 | | - numOfMethodParamsUsedInSystemMessage++; |
1986 | | - } |
1987 | | - } |
1988 | 1988 | } |
1989 | | - if (numOfMethodParamsUsedInSystemMessage != method.parametersCount()) { |
1990 | | - if (method.parametersCount() == 0) { |
1991 | | - throw illegalConfigurationForMethod("Method should have at least one argument", method); |
1992 | | - } |
1993 | | - if (method.parametersCount() == 1) { |
1994 | | - return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(0, userNameParamPosition, |
1995 | | - imageParamPosition, audioParamPosition, pdfParamPosition); |
| 1989 | + int userMessageParamPosition = -1; |
| 1990 | + int undefinedParams = 0; |
| 1991 | + for (int i = 0; i < method.parametersCount(); i++) { |
| 1992 | + MethodParameterInfo parameter = method.parameters().get(i); |
| 1993 | + if (templateParamNames.contains(parameter.name())) { |
| 1994 | + continue; |
| 1995 | + } else if (userNameParamPosition.isPresent() && i == userNameParamPosition.get()) { |
| 1996 | + continue; |
| 1997 | + } else if (imageParamPosition.isPresent() && i == imageParamPosition.get()) { |
| 1998 | + continue; |
| 1999 | + } else if (audioParamPosition.isPresent() && i == audioParamPosition.get()) { |
| 2000 | + continue; |
| 2001 | + } else if (pdfParamPosition.isPresent() && i == pdfParamPosition.get()) { |
| 2002 | + continue; |
| 2003 | + } else if (parameter.type().name().equals(INVOCATION_PARAMETERS)) { |
| 2004 | + continue; |
| 2005 | + } else if (parameter.hasAnnotation(LangChain4jDotNames.MEMORY_ID)) { |
| 2006 | + continue; |
1996 | 2007 | } |
| 2008 | + undefinedParams++; |
| 2009 | + if (undefinedParams > 1) { |
| 2010 | + if (fallbackToDummyUserMesage.test(method)) { |
| 2011 | + return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate( |
| 2012 | + AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(), |
| 2013 | + Optional.empty(), |
| 2014 | + Optional.empty(), Optional.empty()); |
| 2015 | + } |
1997 | 2016 |
|
1998 | | - if (fallbackToDummyUserMesage.test(method)) { |
1999 | | - return AiServiceMethodCreateInfo.UserMessageInfo.fromTemplate( |
2000 | | - AiServiceMethodCreateInfo.TemplateInfo.fromText("", Map.of()), Optional.empty(), |
2001 | | - Optional.empty(), |
2002 | | - Optional.empty(), Optional.empty()); |
| 2017 | + throw illegalConfigurationForMethod( |
| 2018 | + "For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId", |
| 2019 | + method); |
2003 | 2020 | } |
2004 | | - |
2005 | | - throw illegalConfigurationForMethod( |
2006 | | - "For methods with multiple parameters, each parameter must be annotated with @V (or match an template parameter by name), @UserMessage, @UserName or @MemoryId", |
2007 | | - method); |
2008 | | - } else { |
2009 | | - // all method parameters are present in the system message, so there is no user message |
| 2021 | + userMessageParamPosition = i; |
| 2022 | + } |
| 2023 | + if (userMessageParamPosition == -1) { |
| 2024 | + // There is no user message |
2010 | 2025 | return new AiServiceMethodCreateInfo.UserMessageInfo(Optional.empty(), Optional.empty(), Optional.empty(), |
2011 | 2026 | Optional.empty(), Optional.empty(), Optional.empty()); |
| 2027 | + } else { |
| 2028 | + return AiServiceMethodCreateInfo.UserMessageInfo.fromMethodParam(userMessageParamPosition, |
| 2029 | + userNameParamPosition, |
| 2030 | + imageParamPosition, audioParamPosition, pdfParamPosition); |
| 2031 | + |
2012 | 2032 | } |
2013 | 2033 | } |
2014 | 2034 | } |
|
0 commit comments