Commit 9934f65

Merge pull request #1125 from quarkiverse/#1081
Introduce observability into streaming models
2 parents: 9c6b20e + 2560ccc

File tree: 10 files changed (+269, -28 lines)

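This merge teaches the streaming chat models to discover ChatModelListener beans through CDI and to notify them of requests, responses, and errors. For orientation, here is a minimal sketch of a listener an application could register; the class name and log messages are illustrative and not part of this commit:

    import org.jboss.logging.Logger;

    import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
    import dev.langchain4j.model.chat.listener.ChatModelListener;
    import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
    import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
    import jakarta.enterprise.context.ApplicationScoped;

    // Illustrative sketch: any ApplicationScoped ChatModelListener bean is picked
    // up through the Instance<ChatModelListener> injection point added in this commit.
    @ApplicationScoped
    public class LoggingChatModelListener implements ChatModelListener {

        private static final Logger log = Logger.getLogger(LoggingChatModelListener.class);

        @Override
        public void onRequest(ChatModelRequestContext context) {
            // Invoked before the streaming call is sent to the model.
            log.infof("model request: %s", context.request().model());
        }

        @Override
        public void onResponse(ChatModelResponseContext context) {
            // Invoked once the stream completes successfully.
            log.infof("token usage: %s", context.response().tokenUsage());
        }

        @Override
        public void onError(ChatModelErrorContext context) {
            // Invoked when the stream fails; a partial response may be attached.
            log.error("model call failed", context.error());
        }
    }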

model-providers/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java

Lines changed: 3 additions & 1 deletion
@@ -161,7 +161,9 @@ void generateBeans(OllamaRecorder recorder,
                 .setRuntimeInit()
                 .defaultBean()
                 .scope(ApplicationScoped.class)
-                .supplier(recorder.streamingChatModel(config, fixedRuntimeConfig, configName));
+                .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
+                        new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
+                .createWith(recorder.streamingChatModel(config, fixedRuntimeConfig, configName));
         addQualifierIfNecessary(streamingBuilder, configName);
         beanProducer.produce(streamingBuilder.done());
     }
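The DotNames constants referenced in this hunk are defined elsewhere in the extension and are not part of the diff; a plausible sketch of what they resolve to (hypothetical, for orientation only):

    import org.jboss.jandex.DotName;

    // Hypothetical reconstruction; the real DotNames class lives elsewhere in the repo.
    public final class DotNames {
        public static final DotName CDI_INSTANCE = DotName.createSimple("jakarta.enterprise.inject.Instance");
        public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple("dev.langchain4j.model.chat.listener.ChatModelListener");

        private DotNames() {
        }
    }

The effect of the change itself: the synthetic streaming-model bean now declares an Instance<ChatModelListener> injection point and is created through createWith(...), so the recorder can read the listeners at bean-creation time instead of building the model from a plain Supplier.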

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java

Lines changed: 121 additions & 6 deletions
@@ -6,16 +6,27 @@
 
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
+import org.jboss.logging.Logger;
+
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
+import dev.langchain4j.model.chat.listener.ChatModelListener;
+import dev.langchain4j.model.chat.listener.ChatModelRequest;
+import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
+import dev.langchain4j.model.chat.listener.ChatModelResponse;
+import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import io.smallrye.mutiny.Context;
@@ -24,20 +35,26 @@
  * Use to have streaming feature on models used trough Ollama.
  */
 public class OllamaStreamingChatLanguageModel implements StreamingChatLanguageModel {
+
+    private static final Logger log = Logger.getLogger(OllamaStreamingChatLanguageModel.class);
+
     private static final String TOOLS_CONTEXT = "TOOLS";
     private static final String TOKEN_USAGE_CONTEXT = "TOKEN_USAGE";
     private static final String RESPONSE_CONTEXT = "RESPONSE";
+    private static final String MODEL_ID = "MODEL_ID";
     private final OllamaClient client;
     private final String model;
     private final String format;
     private final Options options;
+    private final List<ChatModelListener> listeners;
 
     private OllamaStreamingChatLanguageModel(OllamaStreamingChatLanguageModel.Builder builder) {
         client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
                 builder.configName, builder.tlsConfigurationName);
         model = builder.model;
         format = builder.format;
         options = builder.options;
+        this.listeners = builder.listeners;
     }
 
     public static OllamaStreamingChatLanguageModel.Builder builder() {
@@ -60,13 +77,25 @@ public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpe
                 .build();
 
         Context context = Context.empty();
+        context.put(MODEL_ID, "");
         context.put(RESPONSE_CONTEXT, new ArrayList<ChatResponse>());
         context.put(TOOLS_CONTEXT, new ArrayList<ToolExecutionRequest>());
 
+        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
+        Map<Object, Object> attributes = new ConcurrentHashMap<>();
+        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
+        listeners.forEach(listener -> {
+            try {
+                listener.onRequest(requestContext);
+            } catch (Exception e) {
+                log.warn("Exception while calling model listener", e);
+            }
+        });
+
         client.streamingChat(request)
                 .subscribe()
                 .with(context,
-                        new Consumer<ChatResponse>() {
+                        new Consumer<>() {
                             @Override
                             @SuppressWarnings("unchecked")
                             public void accept(ChatResponse response) {
@@ -89,6 +118,9 @@ public void accept(ChatResponse response) {
                                 }
 
                                 if (response.done()) {
+                                    if (response.model() != null) {
+                                        context.put(MODEL_ID, response.model());
+                                    }
                                     TokenUsage tokenUsage = new TokenUsage(
                                             response.evalCount(),
                                             response.promptEvalCount(),
@@ -101,9 +133,36 @@ public void accept(ChatResponse response) {
                                 }
                             }
                         },
-                        new Consumer<Throwable>() {
+                        new Consumer<>() {
                             @Override
                             public void accept(Throwable error) {
+                                List<ChatResponse> chatResponses = context.get(RESPONSE_CONTEXT);
+                                String stringResponse = chatResponses.stream()
+                                        .map(ChatResponse::message)
+                                        .map(Message::content)
+                                        .collect(Collectors.joining());
+                                AiMessage aiMessage = new AiMessage(stringResponse);
+                                Response<AiMessage> aiMessageResponse = Response.from(aiMessage);
+
+                                ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
+                                        null,
+                                        context.get(MODEL_ID),
+                                        aiMessageResponse);
+
+                                ChatModelErrorContext errorContext = new ChatModelErrorContext(
+                                        error,
+                                        modelListenerRequest,
+                                        modelListenerPartialResponse,
+                                        attributes);
+
+                                listeners.forEach(listener -> {
+                                    try {
+                                        listener.onError(errorContext);
+                                    } catch (Exception e) {
+                                        log.warn("Exception while calling model listener", e);
+                                    }
+                                });
+
                                 handler.onError(error);
                             }
                         },
@@ -115,22 +174,72 @@ public void run() {
                                 List<ChatResponse> chatResponses = context.get(RESPONSE_CONTEXT);
                                 List<ToolExecutionRequest> toolExecutionRequests = context.get(TOOLS_CONTEXT);
 
-                                if (toolExecutionRequests.size() > 0) {
+                                if (!toolExecutionRequests.isEmpty()) {
                                     handler.onComplete(Response.from(AiMessage.from(toolExecutionRequests), tokenUsage));
                                     return;
                                 }
 
-                                String response = chatResponses.stream()
+                                String stringResponse = chatResponses.stream()
                                         .map(ChatResponse::message)
                                         .map(Message::content)
                                         .collect(Collectors.joining());
 
-                                AiMessage message = new AiMessage(response);
-                                handler.onComplete(Response.from(message, tokenUsage));
+                                AiMessage aiMessage = new AiMessage(stringResponse);
+                                Response<AiMessage> aiMessageResponse = Response.from(aiMessage, tokenUsage);
+
+                                ChatModelResponse modelListenerResponse = createModelListenerResponse(
+                                        null,
+                                        context.get(MODEL_ID),
+                                        aiMessageResponse);
+                                ChatModelResponseContext responseContext = new ChatModelResponseContext(
+                                        modelListenerResponse,
+                                        modelListenerRequest,
+                                        attributes);
+                                listeners.forEach(listener -> {
+                                    try {
+                                        listener.onResponse(responseContext);
+                                    } catch (Exception e) {
+                                        log.warn("Exception while calling model listener", e);
+                                    }
+                                });
+
+                                handler.onComplete(aiMessageResponse);
                             }
                         });
     }
 
+    private ChatModelRequest createModelListenerRequest(ChatRequest request,
+            List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications) {
+        Options options = request.options();
+        var builder = ChatModelRequest.builder()
+                .model(request.model())
+                .messages(messages)
+                .toolSpecifications(toolSpecifications);
+        if (options != null) {
+            builder.temperature(options.temperature())
+                    .topP(options.topP())
+                    .maxTokens(options.numPredict());
+        }
+        return builder.build();
+    }
+
+    private ChatModelResponse createModelListenerResponse(String responseId,
+            String responseModel,
+            Response<AiMessage> response) {
+        if (response == null) {
+            return null;
+        }
+
+        return ChatModelResponse.builder()
+                .id(responseId)
+                .model(responseModel)
+                .tokenUsage(response.tokenUsage())
+                .finishReason(response.finishReason())
+                .aiMessage(response.content())
+                .build();
+    }
+
     @Override
     public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification,
             StreamingResponseHandler<AiMessage> handler) {
@@ -161,6 +270,7 @@ private Builder() {
         private boolean logRequests = false;
         private boolean logResponses = false;
         private String configName;
+        private List<ChatModelListener> listeners = Collections.emptyList();
 
         public Builder baseUrl(String val) {
             baseUrl = val;
@@ -207,6 +317,11 @@ public Builder configName(String configName) {
             return this;
         }
 
+        public Builder listeners(List<ChatModelListener> listeners) {
+            this.listeners = listeners;
+            return this;
+        }
+
         public OllamaStreamingChatLanguageModel build() {
             return new OllamaStreamingChatLanguageModel(this);
         }
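With the new listeners(...) builder method, the model can also be wired by hand outside CDI. A minimal sketch using only builder methods visible in this diff, plus the illustrative LoggingChatModelListener from above (the endpoint value is an assumption):

    // Illustrative manual wiring; Quarkus normally does this through the synthetic bean.
    List<ChatModelListener> listeners = List.of(new LoggingChatModelListener());

    OllamaStreamingChatLanguageModel model = OllamaStreamingChatLanguageModel.builder()
            .baseUrl("http://localhost:11434") // illustrative local Ollama endpoint
            .listeners(listeners)
            .build();

Note the defensive pattern in the diff: each listener callback is wrapped in try/catch and failures are only logged, so a misbehaving listener cannot break the stream or suppress handler.onError/onComplete.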

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java

Lines changed: 10 additions & 5 deletions
@@ -133,7 +133,8 @@ public EmbeddingModel get() {
         }
     }
 
-    public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jOllamaConfig runtimeConfig,
+    public Function<SyntheticCreationalContext<StreamingChatLanguageModel>, StreamingChatLanguageModel> streamingChatModel(
+            LangChain4jOllamaConfig runtimeConfig,
             LangChain4jOllamaFixedRuntimeConfig fixedRuntimeConfig, String configName) {
         LangChain4jOllamaConfig.OllamaConfig ollamaConfig = correspondingOllamaConfig(runtimeConfig, configName);
         LangChain4jOllamaFixedRuntimeConfig.OllamaConfig ollamaFixedConfig = correspondingOllamaFixedConfig(fixedRuntimeConfig,
@@ -166,16 +167,20 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jOllama
                     .options(optionsBuilder.build())
                     .configName(NamedConfigUtil.isDefault(configName) ? null : configName);
 
-            return new Supplier<>() {
+            return new Function<>() {
                 @Override
-                public StreamingChatLanguageModel get() {
+                public StreamingChatLanguageModel apply(
+                        SyntheticCreationalContext<StreamingChatLanguageModel> context) {
+                    builder.listeners(context.getInjectedReference(CHAT_MODEL_LISTENER_TYPE_LITERAL).stream()
+                            .collect(Collectors.toList()));
                     return builder.build();
                 }
             };
         } else {
-            return new Supplier<>() {
+            return new Function<>() {
                 @Override
-                public StreamingChatLanguageModel get() {
+                public StreamingChatLanguageModel apply(
+                        SyntheticCreationalContext<StreamingChatLanguageModel> context) {
                     return new DisabledStreamingChatLanguageModel();
                 }
             };
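CHAT_MODEL_LISTENER_TYPE_LITERAL is referenced but not declared in these hunks; for getInjectedReference(...) to resolve the Instance<ChatModelListener> injection point declared by the processor, it would plausibly be defined like this (an assumption, not shown in the diff):

    import dev.langchain4j.model.chat.listener.ChatModelListener;
    import jakarta.enterprise.inject.Instance;
    import jakarta.enterprise.util.TypeLiteral;

    // Assumed declaration; the actual constant lives in OllamaRecorder outside these hunks.
    private static final TypeLiteral<Instance<ChatModelListener>> CHAT_MODEL_LISTENER_TYPE_LITERAL = new TypeLiteral<>() {
    };

Switching from Supplier to Function<SyntheticCreationalContext<...>, ...> is what makes this possible: the creational context carries the injected references declared via addInjectionPoint(...) in the build step.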

model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,8 @@ void generateBeans(AzureOpenAiRecorder recorder,
                 .setRuntimeInit()
                 .defaultBean()
                 .scope(ApplicationScoped.class)
+                .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
+                        new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
                 .createWith(streamingChatModel);
         addQualifierIfNecessary(streamingBuilder, configName);
         beanProducer.produce(streamingBuilder.done());

model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiChatModel.java

Lines changed: 0 additions & 1 deletion
@@ -126,7 +126,6 @@ public AzureOpenAiChatModel(String endpoint,
                 : ResponseFormat.builder()
                         .type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
                         .build();
-        ;
     }
 
     @Override
