Skip to content

Commit 480a390

Browse files
committed
feat: add observability to qianfan chat model
1 parent 1606383 commit 480a390

File tree

12 files changed

+368
-50
lines changed

12 files changed

+368
-50
lines changed

models/spring-ai-qianfan/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@
5454
<scope>test</scope>
5555
</dependency>
5656

57+
<dependency>
58+
<groupId>io.micrometer</groupId>
59+
<artifactId>micrometer-observation-test</artifactId>
60+
<scope>test</scope>
61+
</dependency>
62+
5763
</dependencies>
5864

5965
</project>

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatModel.java

Lines changed: 137 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,25 @@
1515
*/
1616
package org.springframework.ai.qianfan;
1717

18+
import io.micrometer.observation.Observation;
19+
import io.micrometer.observation.ObservationRegistry;
20+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
1821
import org.slf4j.Logger;
1922
import org.slf4j.LoggerFactory;
23+
import org.springframework.ai.chat.messages.AssistantMessage;
24+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
25+
import org.springframework.ai.chat.metadata.EmptyUsage;
2026
import org.springframework.ai.chat.model.ChatModel;
2127
import org.springframework.ai.chat.model.ChatResponse;
2228
import org.springframework.ai.chat.model.Generation;
29+
import org.springframework.ai.chat.model.MessageAggregator;
2330
import org.springframework.ai.chat.model.StreamingChatModel;
31+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
32+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
33+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
34+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
2435
import org.springframework.ai.chat.prompt.ChatOptions;
36+
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
2537
import org.springframework.ai.chat.prompt.Prompt;
2638
import org.springframework.ai.model.ModelOptionsUtils;
2739
import org.springframework.ai.qianfan.api.QianFanApi;
@@ -30,11 +42,14 @@
3042
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage;
3143
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionMessage.Role;
3244
import org.springframework.ai.qianfan.api.QianFanApi.ChatCompletionRequest;
45+
import org.springframework.ai.qianfan.api.QianFanConstants;
46+
import org.springframework.ai.qianfan.metadata.QianFanUsage;
3347
import org.springframework.ai.retry.RetryUtils;
3448
import org.springframework.http.ResponseEntity;
3549
import org.springframework.retry.support.RetryTemplate;
3650
import org.springframework.util.Assert;
3751
import reactor.core.publisher.Flux;
52+
import reactor.core.publisher.Mono;
3853

3954
import java.util.Collections;
4055
import java.util.List;
@@ -45,15 +60,17 @@
4560
* backed by {@link QianFanApi}.
4661
*
4762
* @author Geng Rong
48-
* @since 1.0
4963
* @see ChatModel
5064
* @see StreamingChatModel
5165
* @see QianFanApi
66+
* @since 1.0
5267
*/
5368
public class QianFanChatModel implements ChatModel, StreamingChatModel {
5469

5570
private static final Logger logger = LoggerFactory.getLogger(QianFanChatModel.class);
5671

72+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
73+
5774
/**
5875
* The default options used for the chat completion requests.
5976
*/
@@ -69,6 +86,16 @@ public class QianFanChatModel implements ChatModel, StreamingChatModel {
6986
*/
7087
private final QianFanApi qianFanApi;
7188

89+
/**
90+
* Observation registry used for instrumentation.
91+
*/
92+
private final ObservationRegistry observationRegistry;
93+
94+
/**
95+
* Conventions to use for generating observations.
96+
*/
97+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
98+
7299
/**
73100
* Creates an instance of the QianFanChatModel.
74101
* @param qianFanApi The QianFanApi instance to be used for interacting with the
@@ -98,52 +125,113 @@ public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions options) {
98125
* @param retryTemplate The retry template.
99126
*/
100127
public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions options, RetryTemplate retryTemplate) {
128+
this(qianFanApi, options, retryTemplate, ObservationRegistry.NOOP);
129+
}
130+
131+
/**
132+
* Initializes a new instance of the QianFanChatModel.
133+
* @param qianFanApi The QianFanApi instance to be used for interacting with the
134+
* QianFan Chat API.
135+
* @param options The QianFanChatOptions to configure the chat client.
136+
* @param retryTemplate The retry template.
137+
* @param observationRegistry The ObservationRegistry used for instrumentation.
138+
*/
139+
public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions options, RetryTemplate retryTemplate,
140+
ObservationRegistry observationRegistry) {
101141
Assert.notNull(qianFanApi, "QianFanApi must not be null");
102142
Assert.notNull(options, "Options must not be null");
103143
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
144+
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
104145
this.qianFanApi = qianFanApi;
105146
this.defaultOptions = options;
106147
this.retryTemplate = retryTemplate;
148+
this.observationRegistry = observationRegistry;
107149
}
108150

109151
@Override
110152
public ChatResponse call(Prompt prompt) {
111153

112154
ChatCompletionRequest request = createRequest(prompt, false);
113155

114-
return this.retryTemplate.execute(ctx -> {
156+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
157+
.prompt(prompt)
158+
.provider(QianFanConstants.PROVIDER_NAME)
159+
.requestOptions(buildRequestOptions(request))
160+
.build();
115161

116-
ResponseEntity<ChatCompletion> completionEntity = this.doChatCompletion(request);
162+
return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
163+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
164+
this.observationRegistry)
165+
.observe(() -> {
166+
ResponseEntity<ChatCompletion> completionEntity = this.retryTemplate
167+
.execute(ctx -> this.qianFanApi.chatCompletionEntity(request));
117168

118-
var chatCompletion = completionEntity.getBody();
119-
if (chatCompletion == null) {
120-
logger.warn("No chat completion returned for prompt: {}", prompt);
121-
return new ChatResponse(List.of());
122-
}
169+
var chatCompletion = completionEntity.getBody();
170+
if (chatCompletion == null) {
171+
logger.warn("No chat completion returned for prompt: {}", prompt);
172+
return new ChatResponse(List.of());
173+
}
123174

124-
// if (chatCompletion.baseResponse() != null &&
125-
// chatCompletion.baseResponse().statusCode() != 0) {
126-
// throw new RuntimeException(chatCompletion.baseResponse().message());
127-
// }
175+
// @formatter:off
176+
Map<String, Object> metadata = Map.of(
177+
"id", chatCompletion.id(),
178+
"role", Role.ASSISTANT
179+
);
180+
// @formatter:on
128181

129-
var generation = new Generation(chatCompletion.result(),
130-
Map.of("id", chatCompletion.id(), "role", Role.ASSISTANT));
131-
return new ChatResponse(Collections.singletonList(generation));
132-
});
182+
var assistantMessage = new AssistantMessage(chatCompletion.result(), metadata);
183+
List<Generation> generations = Collections.singletonList(new Generation(assistantMessage));
184+
ChatResponse chatResponse = new ChatResponse(generations, from(chatCompletion, request.model()));
185+
observationContext.setResponse(chatResponse);
186+
return chatResponse;
187+
});
133188
}
134189

135190
@Override
136191
public Flux<ChatResponse> stream(Prompt prompt) {
137-
var request = createRequest(prompt, true);
138192

139-
return retryTemplate.execute(ctx -> {
193+
return Flux.deferContextual(contextView -> {
194+
ChatCompletionRequest request = createRequest(prompt, true);
195+
140196
var completionChunks = this.qianFanApi.chatCompletionStream(request);
141197

142-
return completionChunks.map(this::toChatCompletion).map(chatCompletion -> {
143-
String id = chatCompletion.id();
144-
var generation = new Generation(chatCompletion.result(), Map.of("id", id, "role", Role.ASSISTANT));
145-
return new ChatResponse(Collections.singletonList(generation));
146-
});
198+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
199+
.prompt(prompt)
200+
.provider(QianFanConstants.PROVIDER_NAME)
201+
.requestOptions(buildRequestOptions(request))
202+
.build();
203+
204+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
205+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
206+
this.observationRegistry);
207+
208+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
209+
210+
Flux<ChatResponse> chatResponse = completionChunks.map(this::toChatCompletion)
211+
.switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
212+
// @formatter:off
213+
Map<String, Object> metadata = Map.of(
214+
"id", chatCompletion.id(),
215+
"role", Role.ASSISTANT
216+
);
217+
// @formatter:on
218+
219+
var assistantMessage = new AssistantMessage(chatCompletion.result(), metadata);
220+
List<Generation> generations = Collections.singletonList(new Generation(assistantMessage));
221+
return new ChatResponse(generations, from(chatCompletion, request.model()));
222+
}))
223+
.doOnError(observation::error)
224+
.doFinally(s -> {
225+
// TODO: Consider a custom ObservationContext and
226+
// include additional metadata
227+
// if (s == SignalType.CANCEL) {
228+
// observationContext.setAborted(true);
229+
// }
230+
observation.stop();
231+
})
232+
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
233+
return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse);
234+
147235
});
148236
}
149237

@@ -153,7 +241,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
153241
* @return the ChatCompletion
154242
*/
155243
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
156-
return new ChatCompletion(chunk.id(), chunk.object(), chunk.created(), chunk.result(), chunk.usage());
244+
return new ChatCompletion(chunk.id(), chunk.object(), chunk.created(), chunk.result(), chunk.finishReason(),
245+
chunk.usage());
157246
}
158247

159248
/**
@@ -193,8 +282,30 @@ public ChatOptions getDefaultOptions() {
193282
return QianFanChatOptions.fromOptions(this.defaultOptions);
194283
}
195284

196-
private ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
197-
return this.qianFanApi.chatCompletionEntity(request);
285+
private ChatOptions buildRequestOptions(QianFanApi.ChatCompletionRequest request) {
286+
return ChatOptionsBuilder.builder()
287+
.withModel(request.model())
288+
.withFrequencyPenalty(request.frequencyPenalty())
289+
.withMaxTokens(request.maxTokens())
290+
.withPresencePenalty(request.presencePenalty())
291+
.withStopSequences(request.stop())
292+
.withTemperature(request.temperature())
293+
.withTopP(request.topP())
294+
.build();
295+
}
296+
297+
private ChatResponseMetadata from(QianFanApi.ChatCompletion result, String model) {
298+
Assert.notNull(result, "QianFan ChatCompletionResult must not be null");
299+
return ChatResponseMetadata.builder()
300+
.withId(result.id() != null ? result.id() : "")
301+
.withUsage(result.usage() != null ? QianFanUsage.from(result.usage()) : new EmptyUsage())
302+
.withModel(model)
303+
.withKeyValue("created", result.created() != null ? result.created() : 0L)
304+
.build();
305+
}
306+
307+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
308+
this.observationConvention = observationConvention;
198309
}
199310

200311
}

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public class QianFanChatOptions implements ChatOptions {
5151
* The maximum number of tokens to generate in the chat completion. The total length of input
5252
* tokens and generated tokens is limited by the model's context length.
5353
*/
54-
private @JsonProperty("max_tokens") Integer maxTokens;
54+
private @JsonProperty("max_output_tokens") Integer maxTokens;
5555
/**
5656
* Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
5757
* appear in the text so far, increasing the model's likelihood to talk about new topics.

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanApi.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public class QianFanApi extends AuthApi {
6060
* @param secretKey QianFan secret key.
6161
*/
6262
public QianFanApi(String apiKey, String secretKey) {
63-
this(ApiUtils.DEFAULT_BASE_URL, apiKey, secretKey);
63+
this(QianFanConstants.DEFAULT_BASE_URL, apiKey, secretKey);
6464
}
6565

6666
/**
@@ -110,18 +110,18 @@ public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Bu
110110
* @param responseErrorHandler Response error handler.
111111
*/
112112
public QianFanApi(String baseUrl, String apiKey, String secretKey, RestClient.Builder restClientBuilder,
113-
WebClient.Builder webClientBuilder,ResponseErrorHandler responseErrorHandler) {
113+
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
114114
super(apiKey, secretKey);
115115

116116
this.restClient = restClientBuilder
117117
.baseUrl(baseUrl)
118-
.defaultHeaders(ApiUtils.getJsonContentHeaders())
118+
.defaultHeaders(QianFanUtils.defaultHeaders())
119119
.defaultStatusHandler(responseErrorHandler)
120120
.build();
121121

122122
this.webClient = webClientBuilder
123123
.baseUrl(baseUrl)
124-
.defaultHeaders(ApiUtils.getJsonContentHeaders())
124+
.defaultHeaders(QianFanUtils.defaultHeaders())
125125
.build();
126126
}
127127

@@ -308,6 +308,7 @@ public record ChatCompletion(
308308
@JsonProperty("object") String object,
309309
@JsonProperty("created") Long created,
310310
@JsonProperty("result") String result,
311+
@JsonProperty("finish_reason") String finishReason,
311312
@JsonProperty("usage") Usage usage) {
312313
}
313314

@@ -319,6 +320,7 @@ public record ChatCompletion(
319320
*/
320321
@JsonInclude(Include.NON_NULL)
321322
public record Usage(
323+
@JsonProperty("completion_tokens") Integer completionTokens,
322324
@JsonProperty("prompt_tokens") Integer promptTokens,
323325
@JsonProperty("total_tokens") Integer totalTokens) {
324326

@@ -339,6 +341,7 @@ public record ChatCompletionChunk(
339341
@JsonProperty("object") String object,
340342
@JsonProperty("created") Long created,
341343
@JsonProperty("result") String result,
344+
@JsonProperty("finish_reason") String finishReason,
342345
@JsonProperty("is_end") Boolean end,
343346

344347
@JsonProperty("usage") Usage usage

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/ApiUtils.java renamed to models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanConstants.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
*/
1616
package org.springframework.ai.qianfan.api;
1717

18-
import org.springframework.http.HttpHeaders;
19-
import org.springframework.http.MediaType;
20-
21-
import java.util.function.Consumer;
18+
import org.springframework.ai.observation.conventions.AiProvider;
2219

2320
/**
2421
* The ApiUtils class provides utility methods for working with API requests and
@@ -27,12 +24,10 @@
2724
* @author Geng Rong
2825
* @since 1.0
2926
*/
30-
public class ApiUtils {
27+
public class QianFanConstants {
3128

3229
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom";
3330

34-
public static Consumer<HttpHeaders> getJsonContentHeaders() {
35-
return headers -> headers.setContentType(MediaType.APPLICATION_JSON);
36-
}
31+
public static final String PROVIDER_NAME = AiProvider.QIANFAN.value();
3732

3833
}

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/api/QianFanImageApi.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class QianFanImageApi extends AuthApi {
4444
* @param secretKey QianFan secret key.
4545
*/
4646
public QianFanImageApi(String apiKey, String secretKey) {
47-
this(ApiUtils.DEFAULT_BASE_URL, apiKey, secretKey, RestClient.builder());
47+
this(QianFanConstants.DEFAULT_BASE_URL, apiKey, secretKey, RestClient.builder());
4848
}
4949

5050
/**
@@ -71,7 +71,7 @@ public QianFanImageApi(String baseUrl, String apiKey, String secretKey, RestClie
7171
super(apiKey, secretKey);
7272

7373
this.restClient = restClientBuilder.baseUrl(baseUrl)
74-
.defaultHeaders(ApiUtils.getJsonContentHeaders())
74+
.defaultHeaders(QianFanUtils.defaultHeaders())
7575
.defaultStatusHandler(responseErrorHandler)
7676
.build();
7777
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.springframework.ai.qianfan.api;
2+
3+
import org.springframework.http.HttpHeaders;
4+
import org.springframework.http.MediaType;
5+
6+
import java.util.function.Consumer;
7+
8+
public class QianFanUtils {
9+
10+
public static Consumer<HttpHeaders> defaultHeaders() {
11+
return headers -> headers.setContentType(MediaType.APPLICATION_JSON);
12+
}
13+
14+
}

0 commit comments

Comments
 (0)