Skip to content

Commit 3362655

Browse files
authored
Merge pull request #145 from quarkiverse/#138
Properly support Smallrye Fault Tolerance
2 parents f2f5876 + 9674a38 commit 3362655

File tree

9 files changed

+222
-36
lines changed

9 files changed

+222
-36
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import java.util.function.Predicate;
2424
import java.util.stream.Collectors;
2525

26+
import jakarta.annotation.PreDestroy;
27+
import jakarta.enterprise.context.Dependent;
2628
import jakarta.enterprise.inject.Instance;
29+
import jakarta.inject.Inject;
2730

2831
import org.jboss.jandex.AnnotationInstance;
2932
import org.jboss.jandex.AnnotationTarget;
@@ -50,7 +53,6 @@
5053
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
5154
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
5255
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
53-
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
5456
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
5557
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
5658
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
@@ -59,6 +61,8 @@
5961
import io.quarkus.arc.ArcContainer;
6062
import io.quarkus.arc.InstanceHandle;
6163
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
64+
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
65+
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
6266
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
6367
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
6468
import io.quarkus.arc.processor.BuiltinScope;
@@ -311,16 +315,18 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
311315
: null);
312316

313317
SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
314-
.configure(declarativeAiServiceClassInfo.name())
318+
.configure(QuarkusAiServiceContext.class)
315319
.createWith(recorder.createDeclarativeAiService(
316320
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
317321
toolClassNames, chatMemoryProviderSupplierClassName,
318322
retrieverSupplierClassName,
319323
auditServiceClassSupplierName,
320324
moderationModelSupplierClassName)))
321-
.destroyer(DeclarativeAiServiceBeanDestroyer.class)
322325
.setRuntimeInit()
323-
.scope(bi.getCdiScope());
326+
.addQualifier()
327+
.annotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName)
328+
.done()
329+
.scope(Dependent.class);
324330
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
325331
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL));
326332
needsChatModelBean = true;
@@ -392,6 +398,7 @@ public void handleAiServices(AiServicesRecorder recorder,
392398
CombinedIndexBuildItem indexBuildItem,
393399
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
394400
BuildProducer<GeneratedClassBuildItem> generatedClassProducer,
401+
BuildProducer<GeneratedBeanBuildItem> generatedBeanProducer,
395402
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
396403
BuildProducer<AiServicesMethodBuildItem> aiServicesMethodProducer,
397404
BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer,
@@ -476,7 +483,8 @@ public void handleAiServices(AiServicesRecorder recorder,
476483

477484
Map<String, AiServiceClassCreateInfo> perClassMetadata = new HashMap<>();
478485
if (!ifacesForCreate.isEmpty()) {
479-
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
486+
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
487+
ClassOutput generatedBeanOutput = new GeneratedBeanGizmoAdaptor(generatedBeanProducer);
480488
for (ClassInfo iface : ifacesForCreate) {
481489
Set<MethodInfo> allMethods = new HashSet<>(iface.methods());
482490
JandexUtil.getAllSuperinterfaces(iface, index).forEach(ci -> allMethods.addAll(ci.methods()));
@@ -497,13 +505,22 @@ public void handleAiServices(AiServicesRecorder recorder,
497505
boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName);
498506

499507
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
500-
.classOutput(classOutput)
508+
.classOutput(isRegisteredService ? generatedBeanOutput : generatedClassOutput)
501509
.className(implClassName)
502510
.interfaces(ifaceName, ChatMemoryRemovable.class.getName());
503511
if (isRegisteredService) {
504512
classCreatorBuilder.interfaces(AutoCloseable.class);
505513
}
506514
try (ClassCreator classCreator = classCreatorBuilder.build()) {
515+
if (isRegisteredService) {
516+
// we need to make this a bean, so we need to add the proper scope annotation
517+
ScopeInfo scopeInfo = declarativeAiServiceItems.stream()
518+
.filter(bi -> bi.getServiceClassInfo().equals(iface))
519+
.findFirst().orElseThrow(() -> new IllegalStateException(
520+
"Unable to determine the CDI scope of " + iface))
521+
.getCdiScope();
522+
classCreator.addAnnotation(scopeInfo.getDotName().toString());
523+
}
507524

508525
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
509526
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
@@ -516,37 +533,67 @@ public void handleAiServices(AiServicesRecorder recorder,
516533
String methodId = createMethodId(methodInfo);
517534
perMethodMetadata.put(methodId,
518535
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
519-
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
520-
QuarkusAiServiceContext.class);
521-
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
522-
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
523-
constructor.returnValue(null);
524-
525-
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
526-
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
527-
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
528-
mc.load(ifaceName),
529-
mc.load(methodId));
530-
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
531-
for (int i = 0; i < methodInfo.parametersCount(); i++) {
532-
mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i));
536+
{
537+
MethodCreator ctor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
538+
QuarkusAiServiceContext.class);
539+
ctor.setModifiers(Modifier.PUBLIC);
540+
ctor.addAnnotation(Inject.class);
541+
ctor.getParameterAnnotations(0)
542+
.addAnnotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString())
543+
.add("value", ifaceName);
544+
ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis());
545+
ctor.writeInstanceField(contextField, ctor.getThis(),
546+
ctor.getMethodParam(0));
547+
ctor.returnValue(null);
533548
}
534549

535-
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
536-
ResultHandle inputHandle = mc.newInstance(
537-
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
538-
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
539-
contextHandle, methodCreateInfoHandle, paramsHandle);
540-
541-
ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
542-
mc.returnValue(resultHandle);
550+
{
551+
MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V");
552+
noArgsCtor.setModifiers(Modifier.PUBLIC);
553+
noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis());
554+
noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull());
555+
noArgsCtor.returnValue(null);
556+
}
543557

544-
aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
558+
{ // actual method we need to implement
559+
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
560+
561+
// copy annotations
562+
for (AnnotationInstance annotationInstance : methodInfo.declaredAnnotations()) {
563+
// TODO: we need to review this
564+
if (annotationInstance.name().toString()
565+
.startsWith("org.eclipse.microprofile.faulttolerance")) {
566+
mc.addAnnotation(annotationInstance);
567+
}
568+
}
569+
570+
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
571+
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
572+
mc.load(ifaceName),
573+
mc.load(methodId));
574+
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
575+
for (int i = 0; i < methodInfo.parametersCount(); i++) {
576+
mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i));
577+
}
578+
579+
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
580+
ResultHandle inputHandle = mc.newInstance(
581+
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
582+
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class,
583+
Object[].class),
584+
contextHandle, methodCreateInfoHandle, paramsHandle);
585+
586+
ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
587+
mc.returnValue(resultHandle);
588+
589+
aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
590+
}
545591
}
546592

547593
if (isRegisteredService) {
548594
MethodCreator mc = classCreator.getMethodCreator(
549595
MethodDescriptor.ofMethod(implClassName, "close", void.class));
596+
mc.addAnnotation(PreDestroy.class);
550597
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
551598
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
552599
mc.returnVoid();

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io.quarkiverse.langchain4j.CreatedAware;
2222
import io.quarkiverse.langchain4j.RegisterAiService;
2323
import io.quarkiverse.langchain4j.audit.AuditService;
24+
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;
2425

2526
public class Langchain4jDotNames {
2627
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
@@ -67,4 +68,7 @@ public class Langchain4jDotNames {
6768
static final DotName NO_MODERATION_MODEL_SUPPLIER = DotName.createSimple(
6869
RegisterAiService.NoModerationModelSupplier.class);
6970

71+
static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple(
72+
QuarkusAiServiceContextQualifier.class);
73+
7074
}

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
7575
.loadClass(info.getServiceClassName());
7676

7777
QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass);
78+
// we don't really care about QuarkusAiServices here, all we care about is that it
79+
// properly populates QuarkusAiServiceContext which is what we are trying to construct
7880
var quarkusAiServices = INSTANCE.create(aiServiceContext);
7981

8082
if (info.getLanguageModelSupplierClassName() != null) {
@@ -164,7 +166,7 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
164166
}
165167
}
166168

167-
return (T) quarkusAiServices.build();
169+
return (T) aiServiceContext;
168170
} catch (ClassNotFoundException e) {
169171
throw new IllegalStateException(e);
170172
} catch (InvocationTargetException | NoSuchMethodException | IllegalAccessException

core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ public class QuarkusAiServiceContext extends AiServiceContext {
1111

1212
public AuditService auditService;
1313

14+
// needed by Arc
15+
public QuarkusAiServiceContext() {
16+
super(null);
17+
}
18+
1419
public QuarkusAiServiceContext(Class<?> aiServiceClass) {
1520
super(aiServiceClass);
1621
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package io.quarkiverse.langchain4j.runtime.aiservice;
2+
3+
import static java.lang.annotation.ElementType.PARAMETER;
4+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
5+
6+
import java.lang.annotation.Inherited;
7+
import java.lang.annotation.Retention;
8+
import java.lang.annotation.Target;
9+
10+
import jakarta.enterprise.util.AnnotationLiteral;
11+
import jakarta.inject.Qualifier;
12+
13+
@Qualifier
14+
@Inherited
15+
@Target({ PARAMETER })
16+
@Retention(RUNTIME)
17+
public @interface QuarkusAiServiceContextQualifier {
18+
19+
/**
20+
* The name of class
21+
*/
22+
String value();
23+
24+
class Literal extends AnnotationLiteral<QuarkusAiServiceContextQualifier> implements QuarkusAiServiceContextQualifier {
25+
26+
public static Literal of(String value) {
27+
return new Literal(value);
28+
}
29+
30+
private final String value;
31+
32+
public Literal(String value) {
33+
this.value = value;
34+
}
35+
36+
@Override
37+
public String value() {
38+
return value;
39+
}
40+
}
41+
}

integration-tests/openai/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
<groupId>io.quarkus</groupId>
2626
<artifactId>quarkus-micrometer</artifactId>
2727
</dependency>
28+
<dependency>
29+
<groupId>io.quarkus</groupId>
30+
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
31+
</dependency>
2832
<dependency>
2933
<groupId>io.quarkus</groupId>
3034
<artifactId>quarkus-junit5</artifactId>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package org.acme.example.openai.aiservices;
2+
3+
import jakarta.ws.rs.GET;
4+
import jakarta.ws.rs.Path;
5+
6+
import org.eclipse.microprofile.faulttolerance.Fallback;
7+
8+
import dev.langchain4j.service.SystemMessage;
9+
import io.quarkiverse.langchain4j.RegisterAiService;
10+
11+
@Path("assistant-with-fallback")
12+
public class AssistantResourceWithFallback {
13+
14+
private final Assistant assistant;
15+
16+
public AssistantResourceWithFallback(Assistant assistant) {
17+
this.assistant = assistant;
18+
}
19+
20+
@GET
21+
public String get() {
22+
return assistant.chat("test");
23+
}
24+
25+
@RegisterAiService
26+
interface Assistant {
27+
28+
@SystemMessage("""
29+
Help me: {something}
30+
""")
31+
@Fallback(fallbackMethod = "fallback")
32+
String chat(String message);
33+
34+
static String fallback(String message) {
35+
return "This is a fallback message";
36+
}
37+
}
38+
39+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.acme.example.openai.aiservices;
2+
3+
import static io.restassured.RestAssured.given;
4+
import static org.hamcrest.CoreMatchers.equalTo;
5+
6+
import java.net.URL;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
import io.quarkus.test.common.http.TestHTTPEndpoint;
11+
import io.quarkus.test.common.http.TestHTTPResource;
12+
import io.quarkus.test.junit.QuarkusTest;
13+
14+
@QuarkusTest
15+
class AssistantResourceWithFallbackTest {
16+
17+
@TestHTTPEndpoint(AssistantResourceWithFallback.class)
18+
@TestHTTPResource
19+
URL url;
20+
21+
@Test
22+
public void fallback() {
23+
given()
24+
.baseUri(url.toString())
25+
.get()
26+
.then()
27+
.statusCode(200)
28+
.body(equalTo("This is a fallback message"));
29+
}
30+
}

0 commit comments

Comments
 (0)