Skip to content

Commit b779f6b

Browse files
authored
Merge pull request #1142 from andreadimaio/main
Introduce observability in watsonx
2 parents 5ec8e05 + d7c304b commit b779f6b

File tree

9 files changed

+340
-99
lines changed

9 files changed

+340
-99
lines changed

model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;
88

99
import java.util.List;
10-
import java.util.function.Supplier;
10+
import java.util.function.Function;
1111

1212
import jakarta.enterprise.context.ApplicationScoped;
1313

1414
import org.jboss.jandex.AnnotationInstance;
15+
import org.jboss.jandex.ClassType;
16+
import org.jboss.jandex.ParameterizedType;
17+
import org.jboss.jandex.Type;
1518

1619
import dev.langchain4j.model.chat.ChatLanguageModel;
1720
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
1821
import io.quarkiverse.langchain4j.ModelName;
22+
import io.quarkiverse.langchain4j.deployment.DotNames;
1923
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
2024
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
2125
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
@@ -26,6 +30,7 @@
2630
import io.quarkiverse.langchain4j.watsonx.runtime.WatsonxRecorder;
2731
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig;
2832
import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxFixedRuntimeConfig;
33+
import io.quarkus.arc.SyntheticCreationalContext;
2934
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
3035
import io.quarkus.deployment.Capabilities;
3136
import io.quarkus.deployment.Capability;
@@ -86,8 +91,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
8691
? fixedRuntimeConfig.defaultConfig().mode()
8792
: fixedRuntimeConfig.namedConfig().get(configName).mode();
8893

89-
Supplier<ChatLanguageModel> chatLanguageModel;
90-
Supplier<StreamingChatLanguageModel> streamingChatLanguageModel;
94+
Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel> chatLanguageModel;
95+
Function<SyntheticCreationalContext<StreamingChatLanguageModel>, StreamingChatLanguageModel> streamingChatLanguageModel;
9196

9297
if (mode.equalsIgnoreCase("chat")) {
9398
chatLanguageModel = recorder.chatModel(runtimeConfig, configName);
@@ -106,7 +111,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
106111
.setRuntimeInit()
107112
.defaultBean()
108113
.scope(ApplicationScoped.class)
109-
.supplier(chatLanguageModel);
114+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
115+
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
116+
.createWith(chatLanguageModel);
110117

111118
addQualifierIfNecessary(chatBuilder, configName);
112119
beanProducer.produce(chatBuilder.done());
@@ -116,7 +123,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
116123
.setRuntimeInit()
117124
.defaultBean()
118125
.scope(ApplicationScoped.class)
119-
.supplier(chatLanguageModel);
126+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
127+
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
128+
.createWith(chatLanguageModel);
120129

121130
addQualifierIfNecessary(tokenizerBuilder, configName);
122131
beanProducer.produce(tokenizerBuilder.done());
@@ -126,7 +135,9 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
126135
.setRuntimeInit()
127136
.defaultBean()
128137
.scope(ApplicationScoped.class)
129-
.supplier(streamingChatLanguageModel);
138+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
139+
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
140+
.createWith(streamingChatLanguageModel);
130141

131142
addQualifierIfNecessary(streamingBuilder, configName);
132143
beanProducer.produce(streamingBuilder.done());
@@ -171,9 +182,8 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigur
171182

172183
/**
173184
* When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure
174-
* that Jackson is used.
175-
* This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected
176-
* at random.
185+
* that Jackson is used. This is not a proper solution as it affects all clients, but it's better than having the
186+
* reader/writers be selected at random.
177187
*/
178188
@BuildStep
179189
public void deprioritizeJsonb(Capabilities capabilities,

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static org.awaitility.Awaitility.await;
55
import static org.junit.jupiter.api.Assertions.assertEquals;
66
import static org.junit.jupiter.api.Assertions.assertNotNull;
7+
import static org.junit.jupiter.api.Assertions.fail;
78

89
import java.time.Duration;
910
import java.util.Date;
@@ -21,10 +22,12 @@
2122
import dev.langchain4j.data.embedding.Embedding;
2223
import dev.langchain4j.data.message.AiMessage;
2324
import dev.langchain4j.data.segment.TextSegment;
25+
import dev.langchain4j.model.StreamingResponseHandler;
2426
import dev.langchain4j.model.chat.ChatLanguageModel;
2527
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
2628
import dev.langchain4j.model.chat.TokenCountEstimator;
2729
import dev.langchain4j.model.embedding.EmbeddingModel;
30+
import dev.langchain4j.model.output.FinishReason;
2831
import dev.langchain4j.model.output.Response;
2932
import dev.langchain4j.model.scoring.ScoringModel;
3033
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
@@ -268,7 +271,25 @@ void check_chat_streaming_model_config() throws Exception {
268271
dev.langchain4j.data.message.UserMessage.from("UserMessage"));
269272

270273
var streamingResponse = new AtomicReference<AiMessage>();
271-
streamingChatModel.generate(messages, WireMockUtil.streamingResponseHandler(streamingResponse));
274+
streamingChatModel.generate(messages, new StreamingResponseHandler<>() {
275+
@Override
276+
public void onNext(String token) {
277+
}
278+
279+
@Override
280+
public void onError(Throwable error) {
281+
fail("Streaming failed: %s".formatted(error.getMessage()), error);
282+
}
283+
284+
@Override
285+
public void onComplete(Response<AiMessage> response) {
286+
assertEquals(FinishReason.LENGTH, response.finishReason());
287+
assertEquals(2, response.tokenUsage().inputTokenCount());
288+
assertEquals(14, response.tokenUsage().outputTokenCount());
289+
assertEquals(16, response.tokenUsage().totalTokenCount());
290+
streamingResponse.set(response.content());
291+
}
292+
});
272293

273294
await().atMost(Duration.ofMinutes(1))
274295
.pollInterval(Duration.ofSeconds(2))
@@ -277,5 +298,6 @@ void check_chat_streaming_model_config() throws Exception {
277298
assertThat(streamingResponse.get().text())
278299
.isNotNull()
279300
.isEqualTo(". I'm a beginner");
301+
280302
}
281303
}

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,31 @@
22

33
import java.net.URL;
44
import java.time.Duration;
5+
import java.util.Collections;
6+
import java.util.List;
7+
import java.util.Map;
58
import java.util.concurrent.TimeUnit;
69

10+
import org.jboss.logging.Logger;
711
import org.jboss.resteasy.reactive.client.api.LoggingScope;
812

13+
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
14+
import dev.langchain4j.model.chat.listener.ChatModelListener;
15+
import dev.langchain4j.model.chat.listener.ChatModelRequest;
16+
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
17+
import dev.langchain4j.model.chat.listener.ChatModelResponse;
18+
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
919
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
1020
import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory;
1121
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
1222

1323
public abstract class Watsonx {
1424

25+
private static final Logger logger = Logger.getLogger(Watsonx.class);
26+
1527
protected final String modelId, projectId, spaceId, version;
1628
protected final WatsonxRestApi client;
29+
protected final List<ChatModelListener> listeners;
1730

1831
public Watsonx(Builder<?> builder) {
1932
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder()
@@ -34,6 +47,38 @@ public Watsonx(Builder<?> builder) {
3447
this.spaceId = builder.spaceId;
3548
this.projectId = builder.projectId;
3649
this.version = builder.version;
50+
this.listeners = builder.listeners;
51+
}
52+
53+
/**
 * Notifies every registered {@link ChatModelListener} that a chat request is about to be sent.
 * <p>
 * Each listener receives its own {@link ChatModelRequestContext} built from the given request and attributes.
 * An {@link Exception} thrown by a listener is caught and logged at WARN level, and the remaining
 * listeners are still called.
 *
 * @param request the request handed to the listeners
 * @param attributes the attribute map handed to the listeners through the context
 */
protected void beforeSentRequest(ChatModelRequest request, Map<Object, Object> attributes) {
54+
for (ChatModelListener listener : listeners) {
55+
try {
56+
listener.onRequest(new ChatModelRequestContext(request, attributes));
57+
} catch (Exception e) {
58+
logger.warn("Exception while calling model listener", e);
59+
}
60+
}
61+
}
62+
63+
/**
 * Notifies every registered {@link ChatModelListener} that a response has been received.
 * <p>
 * Each listener receives its own {@link ChatModelResponseContext} built from the response, the request and
 * the attributes. An {@link Exception} thrown by a listener is caught and logged at WARN level, and the
 * remaining listeners are still called.
 *
 * @param response the response handed to the listeners
 * @param request the request handed to the listeners alongside the response
 * @param attributes the attribute map handed to the listeners through the context
 */
protected void afterReceivedResponse(ChatModelResponse response, ChatModelRequest request, Map<Object, Object> attributes) {
64+
for (ChatModelListener listener : listeners) {
65+
try {
66+
listener.onResponse(new ChatModelResponseContext(response, request, attributes));
67+
} catch (Exception e) {
68+
logger.warn("Exception while calling model listener", e);
69+
}
70+
}
71+
}
72+
73+
/**
 * Notifies every registered {@link ChatModelListener} of an error.
 * <p>
 * Each listener receives its own {@link ChatModelErrorContext} built from the error, the request, the
 * partial response and the attributes. An {@link Exception} thrown by a listener is caught and logged at
 * WARN level, and the remaining listeners are still called.
 *
 * @param error the failure reported to the listeners
 * @param request the request handed to the listeners
 * @param partialResponse the partial response handed to the listeners
 *        (NOTE(review): presumably may be {@code null} when nothing was received; confirm against callers)
 * @param attributes the attribute map handed to the listeners through the context
 */
protected void onRequestError(Throwable error, ChatModelRequest request, ChatModelResponse partialResponse,
74+
Map<Object, Object> attributes) {
75+
for (ChatModelListener listener : listeners) {
76+
try {
77+
listener.onError(new ChatModelErrorContext(error, request, partialResponse, attributes));
78+
} catch (Exception e) {
79+
logger.warn("Exception while calling model listener", e);
80+
}
81+
}
3782
}
3883

3984
public WatsonxRestApi getClient() {
@@ -67,6 +112,7 @@ public static abstract class Builder<T extends Builder<T>> {
67112
protected URL url;
68113
protected boolean logResponses;
69114
protected boolean logRequests;
115+
private List<ChatModelListener> listeners = Collections.emptyList();
70116
protected WatsonxTokenGenerator tokenGenerator;
71117

72118
public T modelId(String modelId) {
@@ -99,6 +145,11 @@ public T timeout(Duration timeout) {
99145
return (T) this;
100146
}
101147

148+
/**
 * Sets the {@link ChatModelListener}s for the model being built, replacing the builder's default empty list.
 * <p>
 * The list reference is stored as-is (no copy is made) and is later assigned to the model's
 * {@code listeners} field by the {@code Watsonx} constructor.
 * <p>
 * NOTE(review): a {@code null} argument is stored unchanged, and the listener notification loops iterate
 * over the list without a null check. Confirm that callers never pass {@code null}.
 *
 * @param listeners the listeners to register
 * @return this builder, cast to its concrete type
 */
public T listeners(List<ChatModelListener> listeners) {
149+
this.listeners = listeners;
150+
return (T) this;
151+
}
152+
102153
public T tokenGenerator(WatsonxTokenGenerator tokenGenerator) {
103154
this.tokenGenerator = tokenGenerator;
104155
return (T) this;

0 commit comments

Comments
 (0)