Skip to content

Commit 44206cd

Browse files
committed
feat: add observability to zhipu chat model
1 parent 1606383 commit 44206cd

File tree

7 files changed, with 331 additions and 68 deletions.

models/spring-ai-zhipuai/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,12 @@
5454
<scope>test</scope>
5555
</dependency>
5656

57+
<dependency>
58+
<groupId>io.micrometer</groupId>
59+
<artifactId>micrometer-observation-test</artifactId>
60+
<scope>test</scope>
61+
</dependency>
62+
5763
</dependencies>
5864

5965
</project>

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java

Lines changed: 135 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,9 @@
1515
*/
1616
package org.springframework.ai.zhipuai;
1717

18+
import io.micrometer.observation.Observation;
19+
import io.micrometer.observation.ObservationRegistry;
20+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
1821
import org.slf4j.Logger;
1922
import org.slf4j.LoggerFactory;
2023
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -23,12 +26,19 @@
2326
import org.springframework.ai.chat.messages.UserMessage;
2427
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2528
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
29+
import org.springframework.ai.chat.metadata.EmptyUsage;
2630
import org.springframework.ai.chat.model.AbstractToolCallSupport;
2731
import org.springframework.ai.chat.model.ChatModel;
2832
import org.springframework.ai.chat.model.ChatResponse;
2933
import org.springframework.ai.chat.model.Generation;
34+
import org.springframework.ai.chat.model.MessageAggregator;
3035
import org.springframework.ai.chat.model.StreamingChatModel;
36+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
37+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
38+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
39+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
3140
import org.springframework.ai.chat.prompt.ChatOptions;
41+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
3242
import org.springframework.ai.chat.prompt.Prompt;
3343
import org.springframework.ai.model.ModelOptionsUtils;
3444
import org.springframework.ai.model.function.FunctionCallback;
@@ -47,6 +57,7 @@
4757
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall;
4858
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest;
4959
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.FunctionTool;
60+
import org.springframework.ai.zhipuai.api.ZhiPuApiConstants;
5061
import org.springframework.ai.zhipuai.metadata.ZhiPuAiUsage;
5162
import org.springframework.http.ResponseEntity;
5263
import org.springframework.retry.support.RetryTemplate;
@@ -78,6 +89,8 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod
7889

7990
private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModel.class);
8091

92+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
93+
8194
/**
8295
* The default options used for the chat completion requests.
8396
*/
@@ -93,6 +106,16 @@ public class ZhiPuAiChatModel extends AbstractToolCallSupport implements ChatMod
93106
*/
94107
private final ZhiPuAiApi zhiPuAiApi;
95108

109+
/**
110+
* Observation registry used for instrumentation.
111+
*/
112+
private final ObservationRegistry observationRegistry;
113+
114+
/**
115+
* Conventions to use for generating observations.
116+
*/
117+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
118+
96119
/**
97120
* Creates an instance of the ZhiPuAiChatModel.
98121
* @param zhiPuAiApi The ZhiPuAiApi instance to be used for interacting with the
@@ -124,7 +147,7 @@ public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options) {
124147
*/
125148
public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options,
126149
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
127-
this(zhiPuAiApi, options, functionCallbackContext, List.of(), retryTemplate);
150+
this(zhiPuAiApi, options, functionCallbackContext, List.of(), retryTemplate, ObservationRegistry.NOOP);
128151
}
129152

130153
/**
@@ -135,58 +158,77 @@ public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options,
135158
* @param functionCallbackContext The function callback context.
136159
* @param toolFunctionCallbacks The tool function callbacks.
137160
* @param retryTemplate The retry template.
161+
* @param observationRegistry The ObservationRegistry used for instrumentation.
138162
*/
139163
public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options,
140164
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
141-
RetryTemplate retryTemplate) {
165+
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
142166
super(functionCallbackContext, options, toolFunctionCallbacks);
143167
Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null");
144168
Assert.notNull(options, "Options must not be null");
145169
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
146170
Assert.isTrue(CollectionUtils.isEmpty(options.getFunctionCallbacks()),
147171
"The default function callbacks must be set via the toolFunctionCallbacks constructor parameter");
172+
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
148173
this.zhiPuAiApi = zhiPuAiApi;
149174
this.defaultOptions = options;
150175
this.retryTemplate = retryTemplate;
176+
this.observationRegistry = observationRegistry;
151177
}
152178

153179
@Override
154180
public ChatResponse call(Prompt prompt) {
155181
ChatCompletionRequest request = createRequest(prompt, false);
156182

157-
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
158-
.execute(ctx -> this.zhiPuAiApi.chatCompletionEntity(request));
183+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
184+
.prompt(prompt)
185+
.provider(ZhiPuApiConstants.PROVIDER_NAME)
186+
.requestOptions(buildRequestOptions(request))
187+
.build();
159188

160-
var chatCompletion = completionEntity.getBody();
189+
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
190+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
191+
this.observationRegistry)
192+
.observe(() -> {
161193

162-
if (chatCompletion == null) {
163-
logger.warn("No chat completion returned for prompt: {}", prompt);
164-
return new ChatResponse(List.of());
165-
}
194+
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
195+
.execute(ctx -> this.zhiPuAiApi.chatCompletionEntity(request));
166196

167-
List<Choice> choices = chatCompletion.choices();
197+
var chatCompletion = completionEntity.getBody();
168198

169-
List<Generation> generations = choices.stream().map(choice -> {
170-
// @formatter:off
171-
Map<String, Object> metadata = Map.of(
172-
"id", chatCompletion.id(),
173-
"role", choice.message().role() != null ? choice.message().role().name() : "",
174-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
175-
// @formatter:on
176-
return buildGeneration(choice, metadata);
177-
}).toList();
199+
if (chatCompletion == null) {
200+
logger.warn("No chat completion returned for prompt: {}", prompt);
201+
return new ChatResponse(List.of());
202+
}
203+
204+
List<Choice> choices = chatCompletion.choices();
178205

179-
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
206+
List<Generation> generations = choices.stream().map(choice -> {
207+
// @formatter:off
208+
Map<String, Object> metadata = Map.of(
209+
"id", chatCompletion.id(),
210+
"role", choice.message().role() != null ? choice.message().role().name() : "",
211+
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
212+
);
213+
// @formatter:on
214+
return buildGeneration(choice, metadata);
215+
}).toList();
180216

181-
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
217+
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
218+
219+
observationContext.setResponse(chatResponse);
220+
221+
return chatResponse;
222+
});
223+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
182224
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
183-
var toolCallConversation = handleToolCalls(prompt, chatResponse);
225+
var toolCallConversation = handleToolCalls(prompt, response);
184226
// Recursively call the call method with the tool call message
185227
// conversation that contains the call responses.
186228
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
187229
}
188230

189-
return chatResponse;
231+
return response;
190232
}
191233

192234
@Override
@@ -196,72 +238,87 @@ public ChatOptions getDefaultOptions() {
196238

197239
@Override
198240
public Flux<ChatResponse> stream(Prompt prompt) {
199-
ChatCompletionRequest request = createRequest(prompt, true);
241+
return Flux.deferContextual(contextView -> {
242+
ChatCompletionRequest request = createRequest(prompt, true);
200243

201-
Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
202-
.execute(ctx -> this.zhiPuAiApi.chatCompletionStream(request));
244+
Flux<ChatCompletionChunk> completionChunks = this.retryTemplate
245+
.execute(ctx -> this.zhiPuAiApi.chatCompletionStream(request));
203246

204-
// For chunked responses, only the first chunk contains the choice role.
205-
// The rest of the chunks with same ID share the same role.
206-
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
247+
// For chunked responses, only the first chunk contains the choice role.
248+
// The rest of the chunks with same ID share the same role.
249+
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
207250

208-
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
209-
// the function call handling logic.
210-
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
211-
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
212-
try {
213-
@SuppressWarnings("null")
214-
String id = chatCompletion2.id();
251+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
252+
.prompt(prompt)
253+
.provider(ZhiPuApiConstants.PROVIDER_NAME)
254+
.requestOptions(buildRequestOptions(request))
255+
.build();
215256

216-
// @formatter:off
257+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
258+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
259+
this.observationRegistry);
260+
261+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
262+
263+
Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
264+
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
265+
try {
266+
String id = chatCompletion2.id();
267+
268+
// @formatter:off
217269
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
218270
if (choice.message().role() != null) {
219271
roleMap.putIfAbsent(id, choice.message().role().name());
220272
}
221273
Map<String, Object> metadata = Map.of(
222-
"id", chatCompletion2.id(),
223-
"role", roleMap.getOrDefault(id, ""),
224-
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
274+
"id", chatCompletion2.id(),
275+
"role", roleMap.getOrDefault(id, ""),
276+
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
277+
);
225278
return buildGeneration(choice, metadata);
226279
}).toList();
227280
// @formatter:on
228281

229-
if (chatCompletion2.usage() != null) {
230282
return new ChatResponse(generations, from(chatCompletion2));
231283
}
232-
else {
233-
return new ChatResponse(generations);
284+
catch (Exception e) {
285+
logger.error("Error processing chat completion", e);
286+
return new ChatResponse(List.of());
234287
}
235-
}
236-
catch (Exception e) {
237-
logger.error("Error processing chat completion", e);
238-
return new ChatResponse(List.of());
239-
}
240288

241-
}));
289+
}));
242290

243-
return chatResponse.flatMap(response -> {
244-
245-
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
246-
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
247-
var toolCallConversation = handleToolCalls(prompt, response);
248-
// Recursively call the stream method with the tool call message
249-
// conversation that contains the call responses.
250-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
251-
}
252-
else {
291+
// @formatter:off
292+
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
293+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
294+
var toolCallConversation = handleToolCalls(prompt, response);
295+
// Recursively call the stream method with the tool call message
296+
// conversation that contains the call responses.
297+
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
298+
}
253299
return Flux.just(response);
254-
}
300+
}).doOnError(observation::error).doFinally(s -> {
301+
// TODO: Consider a custom ObservationContext and
302+
// include additional metadata
303+
// if (s == SignalType.CANCEL) {
304+
// observationContext.setAborted(true);
305+
// }
306+
observation.stop();
307+
}).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
308+
// @formatter:on
309+
310+
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
255311
});
256312
}
257313

258314
private ChatResponseMetadata from(ChatCompletion result) {
259315
Assert.notNull(result, "ZhiPuAI ChatCompletionResult must not be null");
260316
return ChatResponseMetadata.builder()
261-
.withId(result.id())
262-
.withUsage(ZhiPuAiUsage.from(result.usage()))
263-
.withModel(result.model())
264-
.withKeyValue("created", result.created())
317+
.withId(result.id() != null ? result.id() : "")
318+
.withUsage(result.usage() != null ? ZhiPuAiUsage.from(result.usage()) : new EmptyUsage())
319+
.withModel(result.model() != null ? result.model() : "")
320+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
321+
.withKeyValue("system-fingerprint", result.systemFingerprint() != null ? result.systemFingerprint() : "")
265322
.build();
266323
}
267324

@@ -406,6 +463,16 @@ else if (mediaContentData instanceof String text) {
406463
}
407464
}
408465

466+
private ChatOptions buildRequestOptions(ZhiPuAiApi.ChatCompletionRequest request) {
467+
return ChatOptionsBuilder.builder()
468+
.withModel(request.model())
469+
.withMaxTokens(request.maxTokens())
470+
.withStopSequences(request.stop())
471+
.withTemperature(request.temperature())
472+
.withTopP(request.topP())
473+
.build();
474+
}
475+
409476
private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
410477
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
411478
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
@@ -414,4 +481,8 @@ private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
414481
}).toList();
415482
}
416483

484+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
485+
this.observationConvention = observationConvention;
486+
}
487+
417488
}

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuApiConstants.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
package org.springframework.ai.zhipuai.api;
22

3+
import org.springframework.ai.observation.conventions.AiProvider;
4+
35
/**
46
* Common value constants for ZhiPu api.
57
*
@@ -10,4 +12,6 @@ public final class ZhiPuApiConstants {
1012

1113
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas";
1214

15+
public static final String PROVIDER_NAME = AiProvider.ZHIPUAI.value();
16+
1317
}

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ public void zhiPuAiChatStreamTransientError() {
164164
public void zhiPuAiChatStreamNonTransientError() {
165165
when(zhiPuAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
166166
.thenThrow(new RuntimeException("Non Transient Error"));
167-
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
167+
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).collectList().block());
168168
}
169169

170170
@Test

0 commit comments

Comments
 (0)