@@ -21,6 +21,7 @@
 import com.azure.ai.openai.OpenAIClientBuilder;
 import com.azure.ai.openai.models.*;
 import com.azure.core.util.BinaryData;
+import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
 import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -37,6 +38,7 @@
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
 import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
@@ -51,8 +53,9 @@
 import org.springframework.ai.observation.conventions.AiProvider;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
+
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
 
 import java.util.ArrayList;
 import java.util.Base64;
@@ -62,6 +65,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
@@ -189,51 +193,83 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
 
-		ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
-		options.setStream(true);
-
-		Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
-			.getChatCompletionsStream(options.getModel(), options);
-
-		final var isFunctionCall = new AtomicBoolean(false);
-		final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
-			// Note: the first chat completions can be ignored when using Azure OpenAI
-			// service which is a known service bug.
-			.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
-			.map(chatCompletions -> {
-				final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
-				isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
-				return chatCompletions;
-			})
-			.windowUntil(chatCompletions -> {
-				if (isFunctionCall.get() && chatCompletions.getChoices()
-					.get(0)
-					.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
-					isFunctionCall.set(false);
-					return true;
+		return Flux.deferContextual(contextView -> {
+			ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
+			options.setStream(true);
+
+			Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
+				.getChatCompletionsStream(options.getModel(), options);
+
+			// For chunked responses, only the first chunk contains the choice role.
+			// The rest of the chunks with same ID share the same role.
+			ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
+
+			ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+				.prompt(prompt)
+				.provider(AiProvider.AZURE_OPENAI.value())
+				.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
+				.build();
+
+			Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+					this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+					this.observationRegistry);
+
+			observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+
+			final var isFunctionCall = new AtomicBoolean(false);
+
+			final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
+				// Note: the first chat completions can be ignored when using Azure OpenAI
+				// service which is a known service bug.
+				.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
+				.map(chatCompletions -> {
+					final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
+					isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
+					return chatCompletions;
+				})
+				.windowUntil(chatCompletions -> {
+					if (isFunctionCall.get() && chatCompletions.getChoices()
+						.get(0)
+						.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
+						isFunctionCall.set(false);
+						return true;
+					}
+					return !isFunctionCall.get();
+				})
+				.concatMapIterable(window -> {
+					final var reduce = window.reduce(MergeUtils.emptyChatCompletions(),
+							MergeUtils::mergeChatCompletions);
+					return List.of(reduce);
+				})
+				.flatMap(mono -> mono);
+
+			return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
+
+				ChatResponse chatResponse = toChatResponse(chatCompletions);
+
+				if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
+						Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
+					var toolCallConversation = handleToolCalls(prompt, chatResponse);
+					// Recursively call the call method with the tool call message
+					// conversation that contains the call responses.
+					return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
 				}
-				return !isFunctionCall.get();
-			})
-			.concatMapIterable(window -> {
-				final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
-				return List.of(reduce);
-			})
-			.flatMap(mono -> mono);
-
-		return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
-
-			ChatResponse chatResponse = toChatResponse(chatCompletions);
-
-			if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
-					Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
-				var toolCallConversation = handleToolCalls(prompt, chatResponse);
-				// Recursively call the call method with the tool call message
-				// conversation that contains the call responses.
-				return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
-			}
-
-			return Mono.just(chatResponse);
+				Flux<ChatResponse> flux = Flux.just(chatResponse).doOnError(observation::error).doFinally(s -> {
+					// TODO: Consider a custom ObservationContext and
+					// include additional metadata
+					// if (s == SignalType.CANCEL) {
+					// observationContext.setAborted(true);
+					// }
+					observation.stop();
+				}).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+				// @formatter:on
+
+				return new MessageAggregator().aggregate(flux, observationContext::setResponse);
+			});
+
 		});
 
 	}
 
 	private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
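
The substance of this change is the Micrometer `Observation` lifecycle threaded through the reactive pipeline: `Flux.deferContextual` delays the work until subscription so the parent observation can be read from the subscriber's Reactor context, `doOnError`/`doFinally` report failure and termination, and `contextWrite` republishes the observation downstream. A minimal standalone sketch of the same pattern outside Spring AI; the name `greeting.stream` and the method `traceGreeting` are illustrative, not part of this PR:

```java
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

public class ReactiveObservationSketch {

	static Flux<String> traceGreeting(ObservationRegistry registry) {
		// Defer so the subscriber's Reactor context is visible when the flux starts.
		return Flux.deferContextual(contextView -> {
			Observation observation = Observation.createNotStarted("greeting.stream", registry);
			// Link to a propagated parent observation, if any, then start.
			observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
				.start();
			return Flux.just("hello", "world")
				.doOnError(observation::error) // record failures on the observation
				.doFinally(signal -> observation.stop()) // stop on complete, error, or cancel
				// Make this observation current for downstream operators.
				.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
		});
	}

}
```

Deferring matters here: the Reactor context exists only at subscription time, so creating and starting the observation eagerly would lose the caller's trace context.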
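The `windowUntil` → `concatMapIterable` → `flatMap` chain carried over from the old implementation splits the chunk stream into windows, closing a window when the tool-call finish reason arrives, and reduces each window to one merged element. The same shape with plain strings, where concatenation is a toy stand-in for `MergeUtils.mergeChatCompletions`:

```java
import java.util.List;

import reactor.core.publisher.Flux;

public class WindowMergeSketch {

	public static void main(String[] args) {
		// A "|" marks the chunk that completes a logical message.
		Flux<String> chunks = Flux.just("to", "ol|", "plain");

		Flux<String> merged = chunks
			// Close the window at the chunk that ends the message; it stays in that window.
			.windowUntil(chunk -> chunk.endsWith("|"))
			// Reduce each window to a single element, as the reduce over
			// MergeUtils.emptyChatCompletions() does in the diff above.
			.concatMapIterable(window -> List.of(window.reduce("", String::concat)))
			.flatMap(mono -> mono);

		merged.subscribe(System.out::println); // prints "tool|" then "plain"
	}

}
```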
@@ -19,7 +19,9 @@
 import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.List;
+import java.util.stream.Collectors;
 
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 
@@ -42,6 +44,7 @@
 import io.micrometer.common.KeyValue;
 import io.micrometer.observation.tck.TestObservationRegistry;
 import io.micrometer.observation.tck.TestObservationRegistryAssert;
+import reactor.core.publisher.Flux;
 
 /**
  * @author Soby Chacko
@@ -57,6 +60,11 @@ class AzureOpenAiChatModelObservationIT {
 	@Autowired
 	TestObservationRegistry observationRegistry;
 
+	@BeforeEach
+	void beforeEach() {
+		observationRegistry.clear();
+	}
+
 	@Test
 	void observationForImperativeChatOperation() {
 
@@ -77,22 +85,63 @@ void observationForImperativeChatOperation() {
 		ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
 		assertThat(responseMetadata).isNotNull();
 
-		validate(responseMetadata);
+		validate(responseMetadata, true);
 	}
 
+	@Test
+	void observationForStreamingChatOperation() {
+
+		var options = AzureOpenAiChatOptions.builder()
+			.withFrequencyPenalty(0.0)
+			.withDeploymentName("gpt-4o")
+			.withMaxTokens(2048)
+			.withPresencePenalty(0.0)
+			.withStop(List.of("this-is-the-end"))
+			.withTemperature(0.7)
+			.withTopP(1.0)
+			.build();
+
+		Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
+
+		Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
+		List<ChatResponse> responses = chatResponseFlux.collectList().block();
+		assertThat(responses).isNotEmpty();
+		assertThat(responses).hasSizeGreaterThan(10);
+
+		String aggregatedResponse = responses.subList(0, responses.size() - 1)
+			.stream()
+			.map(r -> r.getResult().getOutput().getContent())
+			.collect(Collectors.joining());
+		assertThat(aggregatedResponse).isNotEmpty();
+
+		ChatResponse lastChatResponse = responses.get(responses.size() - 1);
+
+		ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
+		assertThat(responseMetadata).isNotNull();
+
+		validate(responseMetadata, false);
+	}
+
-	private void validate(ChatResponseMetadata responseMetadata) {
-		TestObservationRegistryAssert.assertThat(observationRegistry)
+	private void validate(ChatResponseMetadata responseMetadata, boolean checkModel) {
+
+		TestObservationRegistryAssert.That that = TestObservationRegistryAssert.assertThat(observationRegistry)
 			.doesNotHaveAnyRemainingCurrentObservation()
-			.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
-			.that()
+			.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME);
+
+		// TODO - Investigate why streaming does not contain model in the response.
+		if (checkModel) {
+			that.that()
+				.hasLowCardinalityKeyValue(
+						ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
+						responseMetadata.getModel());
+		}
+
+		that.that()
 			.hasLowCardinalityKeyValue(
 					ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
 					AiOperationType.CHAT.value())
 			.hasLowCardinalityKeyValue(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(),
 					AiProvider.AZURE_OPENAI.value())
-			.hasLowCardinalityKeyValue(
-					ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
-					responseMetadata.getModel())
 			.hasHighCardinalityKeyValue(
 					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(),
 					"0.0")
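
The new `@BeforeEach` is what lets the two tests coexist: they share one `TestObservationRegistry`, so without a reset the streaming test's assertions would also see the observation recorded by the imperative test. A small sketch of that interaction using Micrometer's observation test kit; the operation names are made up:

```java
import io.micrometer.observation.Observation;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

public class RegistryResetSketch {

	public static void main(String[] args) {
		TestObservationRegistry registry = TestObservationRegistry.create();

		Observation.createNotStarted("first.op", registry).observe(() -> {
		});
		TestObservationRegistryAssert.assertThat(registry).hasObservationWithNameEqualTo("first.op");

		// Without this reset, "first.op" would still be visible to later assertions.
		registry.clear();

		Observation.createNotStarted("second.op", registry).observe(() -> {
		});
		TestObservationRegistryAssert.assertThat(registry)
			.doesNotHaveAnyRemainingCurrentObservation()
			.hasObservationWithNameEqualTo("second.op");
	}

}
```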