5151import org .objectweb .asm .tree .analysis .AnalyzerException ;
5252
5353import dev .langchain4j .exception .IllegalConfigurationException ;
54- import dev .langchain4j .service .Moderate ;
5554import io .quarkiverse .langchain4j .ModelName ;
5655import io .quarkiverse .langchain4j .ToolBox ;
5756import io .quarkiverse .langchain4j .deployment .items .SelectedChatModelProviderBuildItem ;
@@ -185,6 +184,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
185184
186185 Set <String > chatModelNames = new HashSet <>();
187186 Set <String > moderationModelNames = new HashSet <>();
187+
188188 for (AnnotationInstance instance : index .getAnnotations (LangChain4jDotNames .REGISTER_AI_SERVICES )) {
189189 if (instance .target ().kind () != AnnotationTarget .Kind .CLASS ) {
190190 continue ; // should never happen
@@ -206,14 +206,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
206206 }
207207
208208 String chatModelName = NamedConfigUtil .DEFAULT_NAME ;
209+ String moderationModelName = NamedConfigUtil .DEFAULT_NAME ;
210+ String embeddingModelName = getModelName (instance .value ("modelName" ));
211+
209212 if (chatLanguageModelSupplierClassDotName == null ) {
210213 AnnotationValue modelNameValue = instance .value ("modelName" );
211- if (modelNameValue != null ) {
212- String modelNameValueStr = modelNameValue .asString ();
213- if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
214- chatModelName = modelNameValueStr ;
215- }
216- }
214+ chatModelName = getModelName (modelNameValue );
217215 chatModelNames .add (chatModelName );
218216 }
219217
@@ -239,6 +237,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
239237 }
240238 }
241239
240+ // the default value depends on whether tools exists or not - if they do, then we require a AiCacheProvider bean
241+ DotName aiCacheProviderSupplierClassDotName = LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER ;
242+ AnnotationValue aiCacheProviderSupplierValue = instance .value ("cacheProviderSupplier" );
243+ if (aiCacheProviderSupplierValue != null ) {
244+ aiCacheProviderSupplierClassDotName = aiCacheProviderSupplierValue .asClass ().name ();
245+ if (!aiCacheProviderSupplierClassDotName
246+ .equals (LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER )) {
247+ validateSupplierAndRegisterForReflection (aiCacheProviderSupplierClassDotName , index ,
248+ reflectiveClassProducer );
249+ }
250+ }
251+
242252 DotName retrieverClassDotName = null ;
243253 AnnotationValue retrieverValue = instance .value ("retriever" );
244254 if (retrieverValue != null ) {
@@ -292,17 +302,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
292302 }
293303
294304 // determine whether the method is annotated with @Moderate
295- String moderationModelName = NamedConfigUtil .DEFAULT_NAME ;
296305 for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
297306 if (method .hasAnnotation (LangChain4jDotNames .MODERATE )) {
298307 if (moderationModelSupplierClassName .equals (LangChain4jDotNames .BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER )) {
299308 AnnotationValue modelNameValue = instance .value ("modelName" );
300- if (modelNameValue != null ) {
301- String modelNameValueStr = modelNameValue .asString ();
302- if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
303- moderationModelName = modelNameValueStr ;
304- }
305- }
309+ moderationModelName = getModelName (modelNameValue );
306310 moderationModelNames .add (moderationModelName );
307311 }
308312 break ;
@@ -321,13 +325,16 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
321325 chatLanguageModelSupplierClassDotName ,
322326 toolDotNames ,
323327 chatMemoryProviderSupplierClassDotName ,
328+ aiCacheProviderSupplierClassDotName ,
324329 retrieverClassDotName ,
325330 retrievalAugmentorSupplierClassName ,
326331 customRetrievalAugmentorSupplierClassIsABean ,
327332 auditServiceSupplierClassName ,
328333 moderationModelSupplierClassName ,
329334 cdiScope ,
330- chatModelName , moderationModelName ));
335+ chatModelName ,
336+ moderationModelName ,
337+ embeddingModelName ));
331338 }
332339
333340 for (String chatModelName : chatModelNames ) {
@@ -361,7 +368,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
361368 List <DeclarativeAiServiceBuildItem > declarativeAiServiceItems ,
362369 List <SelectedChatModelProviderBuildItem > selectedChatModelProvider ,
363370 BuildProducer <SyntheticBeanBuildItem > syntheticBeanProducer ,
364- BuildProducer <UnremovableBeanBuildItem > unremoveableProducer ) {
371+ BuildProducer <UnremovableBeanBuildItem > unremoveableProducer ,
372+ AiCacheBuildItem aiCacheBuildItem ) {
365373
366374 boolean needsChatModelBean = false ;
367375 boolean needsStreamingChatModelBean = false ;
@@ -370,6 +378,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
370378 boolean needsRetrievalAugmentorBean = false ;
371379 boolean needsAuditServiceBean = false ;
372380 boolean needsModerationModelBean = false ;
381+ boolean needsAiCacheProvider = false ;
382+
373383 Set <DotName > allToolNames = new HashSet <>();
374384
375385 for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems ) {
@@ -386,6 +396,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
386396 ? bi .getChatMemoryProviderSupplierClassDotName ().toString ()
387397 : null ;
388398
399+ String aiCacheProviderSupplierClassName = bi .getAiCacheProviderSupplierClassDotName () != null
400+ ? bi .getAiCacheProviderSupplierClassDotName ().toString ()
401+ : null ;
402+
389403 String retrieverClassName = bi .getRetrieverClassDotName () != null
390404 ? bi .getRetrieverClassDotName ().toString ()
391405 : null ;
@@ -403,7 +417,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
403417 : null );
404418
405419 // determine whether the method returns Multi<String>
406- boolean injectStreamingChatModelBean = false ;
420+ boolean needsStreamingChatModel = false ;
407421 for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
408422 if (!LangChain4jDotNames .MULTI .equals (method .returnType ().name ())) {
409423 continue ;
@@ -419,29 +433,36 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
419433 throw illegalConfiguration ("Only Multi<String> is supported as a Multi return type. Offending method is '"
420434 + method .declaringClass ().name ().toString () + "#" + method .name () + "'" );
421435 }
422- injectStreamingChatModelBean = true ;
436+ needsStreamingChatModel = true ;
423437 }
424438
425- boolean injectModerationModelBean = false ;
439+ boolean needsModerationModel = false ;
426440 for (MethodInfo method : declarativeAiServiceClassInfo .methods ()) {
427- if (method .hasAnnotation (Moderate . class )) {
428- injectModerationModelBean = true ;
441+ if (method .hasAnnotation (LangChain4jDotNames . MODERATE )) {
442+ needsModerationModel = true ;
429443 break ;
430444 }
431445 }
432446
433447 String chatModelName = bi .getChatModelName ();
434448 String moderationModelName = bi .getModerationModelName ();
449+ String embeddingModelName = bi .getEmbeddingModelName ();
450+ boolean enableCache = aiCacheBuildItem .isEnable ();
451+
435452 SyntheticBeanBuildItem .ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
436453 .configure (QuarkusAiServiceContext .class )
437454 .forceApplicationClass ()
438455 .createWith (recorder .createDeclarativeAiService (
439456 new DeclarativeAiServiceCreateInfo (serviceClassName , chatLanguageModelSupplierClassName ,
440- toolClassNames , chatMemoryProviderSupplierClassName , retrieverClassName ,
457+ toolClassNames , chatMemoryProviderSupplierClassName , aiCacheProviderSupplierClassName ,
458+ retrieverClassName ,
441459 retrievalAugmentorSupplierClassName ,
442460 auditServiceClassSupplierName , moderationModelSupplierClassName , chatModelName ,
443461 moderationModelName ,
444- injectStreamingChatModelBean , injectModerationModelBean )))
462+ embeddingModelName ,
463+ needsStreamingChatModel ,
464+ needsModerationModel ,
465+ enableCache )))
445466 .setRuntimeInit ()
446467 .addQualifier ()
447468 .annotation (LangChain4jDotNames .QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER ).addValue ("value" , serviceClassName )
@@ -451,15 +472,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
451472 if ((chatLanguageModelSupplierClassName == null ) && !selectedChatModelProvider .isEmpty ()) {
452473 if (NamedConfigUtil .isDefault (chatModelName )) {
453474 configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .CHAT_MODEL ));
454- if (injectStreamingChatModelBean ) {
475+ if (needsStreamingChatModel ) {
455476 configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .STREAMING_CHAT_MODEL ));
456477 needsStreamingChatModelBean = true ;
457478 }
458479 } else {
459480 configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .CHAT_MODEL ),
460481 AnnotationInstance .builder (ModelName .class ).add ("value" , chatModelName ).build ());
461482
462- if (injectStreamingChatModelBean ) {
483+ if (needsStreamingChatModel ) {
463484 configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .STREAMING_CHAT_MODEL ),
464485 AnnotationInstance .builder (ModelName .class ).add ("value" , chatModelName ).build ());
465486 needsStreamingChatModelBean = true ;
@@ -515,7 +536,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
515536 }
516537
517538 if (LangChain4jDotNames .BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER .toString ()
518- .equals (moderationModelSupplierClassName ) && injectModerationModelBean ) {
539+ .equals (moderationModelSupplierClassName ) && needsModerationModel ) {
519540
520541 if (NamedConfigUtil .isDefault (moderationModelName )) {
521542 configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .MODERATION_MODEL ));
@@ -527,6 +548,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
527548 needsModerationModelBean = true ;
528549 }
529550
551+ if (enableCache ) {
552+ if (LangChain4jDotNames .BEAN_AI_CACHE_PROVIDER_SUPPLIER .toString ().equals (aiCacheProviderSupplierClassName )) {
553+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .AI_CACHE_PROVIDER ));
554+ }
555+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .AI_CACHE_PROVIDER ));
556+ configurator .addInjectionPoint (ClassType .create (LangChain4jDotNames .EMBEDDING_MODEL ));
557+ needsAiCacheProvider = true ;
558+ }
559+
530560 syntheticBeanProducer .produce (configurator .done ());
531561 }
532562
@@ -551,6 +581,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
551581 if (needsModerationModelBean ) {
552582 unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .MODERATION_MODEL ));
553583 }
584+ if (needsAiCacheProvider ) {
585+ unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .AI_CACHE_PROVIDER ));
586+ unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (LangChain4jDotNames .EMBEDDING_MODEL ));
587+ }
554588 if (!allToolNames .isEmpty ()) {
555589 unremoveableProducer .produce (UnremovableBeanBuildItem .beanTypes (allToolNames ));
556590 }
@@ -870,6 +904,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
870904 }
871905
872906 boolean requiresModeration = method .hasAnnotation (LangChain4jDotNames .MODERATE );
907+ boolean requiresCache = method .declaringClass ().hasDeclaredAnnotation (LangChain4jDotNames .CACHE_RESULT )
908+ || method .hasDeclaredAnnotation (LangChain4jDotNames .CACHE_RESULT );
873909
874910 List <MethodParameterInfo > params = method .parameters ();
875911
@@ -887,7 +923,7 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
887923 List <String > methodToolClassNames = gatherMethodToolClassNames (method );
888924
889925 return new AiServiceMethodCreateInfo (method .declaringClass ().name ().toString (), method .name (), systemMessageInfo ,
890- userMessageInfo , memoryIdParamPosition , requiresModeration ,
926+ userMessageInfo , memoryIdParamPosition , requiresModeration , requiresCache ,
891927 returnType , metricsTimedInfo , metricsCountedInfo , spanInfo , methodToolClassNames );
892928 }
893929
@@ -1222,6 +1258,16 @@ static Map<String, Integer> toNameToArgsPositionMap(List<TemplateParameterInfo>
12221258 }
12231259 }
12241260
1261+ private String getModelName (AnnotationValue value ) {
1262+ if (value != null ) {
1263+ String modelNameValueStr = value .asString ();
1264+ if ((modelNameValueStr != null ) && !modelNameValueStr .isEmpty ()) {
1265+ return modelNameValueStr ;
1266+ }
1267+ }
1268+ return NamedConfigUtil .DEFAULT_NAME ;
1269+ }
1270+
12251271 public static final class AiServicesMethodBuildItem extends MultiBuildItem {
12261272
12271273 private final MethodInfo methodInfo ;
0 commit comments