Skip to content

Commit 8969261

Browse files
committed
Only construct Observation.Context for Chat and Vector ops on Model observations.
* Simplify and polish OpenAI code by using the API with Spring constructs and Builders. * Simplify common code expressions used in AI provider ChatModels with ValueUtils. * Apply whitespace to improve readability. Closes #1661
1 parent 0d2d4b7 commit 8969261

File tree

16 files changed

+945
-405
lines changed

16 files changed

+945
-405
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Set;
25+
import java.util.function.Supplier;
2526
import java.util.stream.Collectors;
2627

2728
import io.micrometer.observation.Observation;
@@ -54,10 +55,12 @@
5455
import org.springframework.ai.chat.observation.ChatModelObservationContext;
5556
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
5657
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
58+
import org.springframework.ai.chat.observation.ChatModelObservationSupport;
5759
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
5860
import org.springframework.ai.chat.prompt.ChatOptions;
5961
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
6062
import org.springframework.ai.chat.prompt.Prompt;
63+
import org.springframework.ai.model.Content;
6164
import org.springframework.ai.model.ModelOptionsUtils;
6265
import org.springframework.ai.model.function.FunctionCallback;
6366
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -77,6 +80,7 @@
7780
* @author Mariusz Bernacki
7881
* @author Thomas Vitale
7982
* @author Claudio Silva Junior
83+
* @author John Blum
8084
* @since 1.0.0
8185
*/
8286
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
@@ -209,28 +213,31 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
209213

210214
@Override
211215
public ChatResponse call(Prompt prompt) {
216+
212217
ChatCompletionRequest request = createRequest(prompt, false);
213218

214-
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
219+
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
215220
.prompt(prompt)
216221
.provider(AnthropicApi.PROVIDER_NAME)
217222
.requestOptions(buildRequestOptions(request))
218223
.build();
219224

220-
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
221-
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
222-
this.observationRegistry)
223-
.observe(() -> {
225+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
226+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
227+
this.observationRegistry);
228+
229+
ChatResponse response = observation.observe(() -> {
224230

225-
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
226-
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
231+
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
232+
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));
227233

228-
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
234+
ChatResponse chatResponse = toChatResponse(completionEntity.getBody());
229235

230-
observationContext.setResponse(chatResponse);
236+
ChatModelObservationSupport.getObservationContext(observation)
237+
.ifPresent(context -> context.setResponse(chatResponse));
231238

232-
return chatResponse;
233-
});
239+
return chatResponse;
240+
});
234241

235242
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
236243
&& this.isToolCall(response, Set.of("tool_use"))) {
@@ -243,17 +250,19 @@ public ChatResponse call(Prompt prompt) {
243250

244251
@Override
245252
public Flux<ChatResponse> stream(Prompt prompt) {
253+
246254
return Flux.deferContextual(contextView -> {
255+
247256
ChatCompletionRequest request = createRequest(prompt, true);
248257

249-
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
250-
.prompt(prompt)
251-
.provider(AnthropicApi.PROVIDER_NAME)
258+
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
252259
.requestOptions(buildRequestOptions(request))
260+
.provider(AnthropicApi.PROVIDER_NAME)
261+
.prompt(prompt)
253262
.build();
254263

255264
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
256-
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
265+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
257266
this.observationRegistry);
258267

259268
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
@@ -276,7 +285,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
276285
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
277286
// @formatter:on
278287

279-
return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
288+
return new MessageAggregator().aggregate(chatResponseFlux,
289+
ChatModelObservationSupport.setChatResponseInObservationContext(observation));
280290
});
281291
}
282292

@@ -408,7 +418,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
408418
String systemPrompt = prompt.getInstructions()
409419
.stream()
410420
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
411-
.map(m -> m.getContent())
421+
.map(Content::getContent)
412422
.collect(Collectors.joining(System.lineSeparator()));
413423

414424
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import java.util.Optional;
2626
import java.util.Set;
2727
import java.util.concurrent.ConcurrentHashMap;
28+
import java.util.concurrent.ConcurrentMap;
2829
import java.util.concurrent.atomic.AtomicBoolean;
30+
import java.util.function.Supplier;
2931

3032
import com.azure.ai.openai.OpenAIAsyncClient;
3133
import com.azure.ai.openai.OpenAIClient;
@@ -78,6 +80,7 @@
7880
import org.springframework.ai.chat.observation.ChatModelObservationContext;
7981
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
8082
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
83+
import org.springframework.ai.chat.observation.ChatModelObservationSupport;
8184
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
8285
import org.springframework.ai.chat.prompt.ChatOptions;
8386
import org.springframework.ai.chat.prompt.Prompt;
@@ -87,6 +90,7 @@
8790
import org.springframework.ai.model.function.FunctionCallbackContext;
8891
import org.springframework.ai.model.function.FunctionCallingOptions;
8992
import org.springframework.ai.observation.conventions.AiProvider;
93+
import org.springframework.ai.util.ValueUtils;
9094
import org.springframework.util.Assert;
9195
import org.springframework.util.CollectionUtils;
9296

@@ -195,24 +199,24 @@ public AzureOpenAiChatOptions getDefaultOptions() {
195199
@Override
196200
public ChatResponse call(Prompt prompt) {
197201

198-
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
202+
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
199203
.prompt(prompt)
200204
.provider(AiProvider.AZURE_OPENAI.value())
201205
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
202206
.build();
203207

204-
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
205-
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
206-
this.observationRegistry)
207-
.observe(() -> {
208-
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
209-
options.setStream(false);
210-
211-
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
212-
ChatResponse chatResponse = toChatResponse(chatCompletions);
213-
observationContext.setResponse(chatResponse);
214-
return chatResponse;
215-
});
208+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
209+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
210+
this.observationRegistry);
211+
212+
ChatResponse response = observation.observe(() -> {
213+
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(false);
214+
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
215+
ChatResponse chatResponse = toChatResponse(chatCompletions);
216+
ChatModelObservationSupport.getObservationContext(observation)
217+
.ifPresent(context -> context.setResponse(chatResponse));
218+
return chatResponse;
219+
});
216220

217221
if (!isProxyToolCalls(prompt, this.defaultOptions)
218222
&& isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
@@ -229,24 +233,28 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
229233
public Flux<ChatResponse> stream(Prompt prompt) {
230234

231235
return Flux.deferContextual(contextView -> {
232-
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
233-
options.setStream(true);
236+
237+
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt).setStream(true);
234238

235239
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
236240
.getChatCompletionsStream(options.getModel(), options);
237241

242+
// @formatter:off
238243
// For chunked responses, only the first chunk contains the choice role.
239244
// The rest of the chunks with same ID share the same role.
240-
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
245+
// TODO: Why is roleMap not used? I am guessing it should have served the same
246+
// purpose as the roleMap in OpenAiChatModel.stream(:Prompt)
247+
// @formatter:on
248+
ConcurrentMap<String, String> roleMap = new ConcurrentHashMap<>();
241249

242-
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
243-
.prompt(prompt)
250+
Supplier<ChatModelObservationContext> observationContext = () -> ChatModelObservationContext.builder()
251+
.requestOptions(ValueUtils.defaultIfNull(prompt.getOptions(), this.defaultOptions))
244252
.provider(AiProvider.AZURE_OPENAI.value())
245-
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
253+
.prompt(prompt)
246254
.build();
247255

248256
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
249-
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
257+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, observationContext,
250258
this.observationRegistry);
251259

252260
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
@@ -295,7 +303,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
295303
.doFinally(s -> observation.stop())
296304
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
297305

298-
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
306+
return new MessageAggregator().aggregate(flux,
307+
ChatModelObservationSupport.setChatResponseInObservationContext(observation));
299308
});
300309

301310
});

0 commit comments

Comments (0)