diff --git a/models/spring-ai-azure-openai/pom.xml b/models/spring-ai-azure-openai/pom.xml
index d70c098a074..63443562899 100644
--- a/models/spring-ai-azure-openai/pom.xml
+++ b/models/spring-ai-azure-openai/pom.xml
@@ -58,6 +58,12 @@
spring-boot-starter-test
test
+
+
+ io.micrometer
+ micrometer-observation-test
+ test
+
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
index 62c4f2198c2..c61b6e26f30 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java
@@ -16,11 +16,16 @@
package org.springframework.ai.azure.openai;
-import com.azure.ai.openai.OpenAIAsyncClient;
-import com.azure.ai.openai.OpenAIClient;
-import com.azure.ai.openai.OpenAIClientBuilder;
-import com.azure.ai.openai.models.*;
-import com.azure.core.util.BinaryData;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
@@ -36,27 +41,51 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.observation.ChatModelObservationContext;
+import org.springframework.ai.chat.observation.ChatModelObservationConvention;
+import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
+
+import com.azure.ai.openai.OpenAIAsyncClient;
+import com.azure.ai.openai.OpenAIClient;
+import com.azure.ai.openai.OpenAIClientBuilder;
+import com.azure.ai.openai.models.ChatChoice;
+import com.azure.ai.openai.models.ChatCompletions;
+import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
+import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
+import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
+import com.azure.ai.openai.models.ChatCompletionsOptions;
+import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
+import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
+import com.azure.ai.openai.models.ChatCompletionsToolCall;
+import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
+import com.azure.ai.openai.models.ChatMessageContentItem;
+import com.azure.ai.openai.models.ChatMessageImageContentItem;
+import com.azure.ai.openai.models.ChatMessageImageUrl;
+import com.azure.ai.openai.models.ChatMessageTextContentItem;
+import com.azure.ai.openai.models.ChatRequestAssistantMessage;
+import com.azure.ai.openai.models.ChatRequestMessage;
+import com.azure.ai.openai.models.ChatRequestSystemMessage;
+import com.azure.ai.openai.models.ChatRequestToolMessage;
+import com.azure.ai.openai.models.ChatRequestUserMessage;
+import com.azure.ai.openai.models.CompletionsFinishReason;
+import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
+import com.azure.ai.openai.models.FunctionCall;
+import com.azure.ai.openai.models.FunctionDefinition;
+import com.azure.core.util.BinaryData;
+import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
-
/**
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
@@ -81,6 +110,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
private static final Double DEFAULT_TEMPERATURE = 0.7;
+ private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
+
/**
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
*/
@@ -96,8 +127,18 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
*/
private final AzureOpenAiChatOptions defaultOptions;
- public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) {
- this(microsoftOpenAiClient,
+ /**
+ * Observation registry used for instrumentation.
+ */
+ private final ObservationRegistry observationRegistry;
+
+ /**
+ * Conventions to use for generating observations.
+ */
+ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
+
+ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
+ this(openAIClientBuilder,
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
@@ -115,12 +156,19 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) {
+ this(openAIClientBuilder, options, functionCallbackContext, List.of(), ObservationRegistry.NOOP);
+ }
+
+ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
+ FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks,
+ ObservationRegistry observationRegistry) {
super(functionCallbackContext, options, toolFunctionCallbacks);
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
this.openAIClient = openAIClientBuilder.buildClient();
this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient();
this.defaultOptions = options;
+ this.observationRegistry = observationRegistry;
}
public AzureOpenAiChatOptions getDefaultOptions() {
@@ -130,22 +178,34 @@ public AzureOpenAiChatOptions getDefaultOptions() {
@Override
public ChatResponse call(Prompt prompt) {
- ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
- options.setStream(false);
+ ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+ .prompt(prompt)
+ .provider(AiProvider.AZURE_OPENAI.value())
+ .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
+ .build();
- ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
+ ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
+ .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+ this.observationRegistry)
+ .observe(() -> {
+ ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
+ options.setStream(false);
- ChatResponse chatResponse = toChatResponse(chatCompletions);
+ ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
+ ChatResponse chatResponse = toChatResponse(chatCompletions);
+ observationContext.setResponse(chatResponse);
+ return chatResponse;
+ });
if (!isProxyToolCalls(prompt, this.defaultOptions)
- && isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
- var toolCallConversation = handleToolCalls(prompt, chatResponse);
+ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
+ var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}
- return chatResponse;
+ return response;
}
@Override
diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java
new file mode 100644
index 00000000000..df4907a09ee
--- /dev/null
+++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.azure.openai;
+
+import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
+import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.observation.conventions.AiOperationType;
+import org.springframework.ai.observation.conventions.AiProvider;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.SpringBootConfiguration;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
+
+import com.azure.ai.openai.OpenAIClientBuilder;
+import com.azure.ai.openai.OpenAIServiceVersion;
+import com.azure.core.credential.AzureKeyCredential;
+import com.azure.core.http.policy.HttpLogOptions;
+import io.micrometer.common.KeyValue;
+import io.micrometer.observation.tck.TestObservationRegistry;
+import io.micrometer.observation.tck.TestObservationRegistryAssert;
+
+/**
+ * @author Soby Chacko
+ */
+@SpringBootTest(classes = AzureOpenAiChatModelObservationIT.TestConfiguration.class)
+@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
+@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
+class AzureOpenAiChatModelObservationIT {
+
+ @Autowired
+ private AzureOpenAiChatModel chatModel;
+
+ @Autowired
+ TestObservationRegistry observationRegistry;
+
+ @Test
+ void observationForImperativeChatOperation() {
+
+ var options = AzureOpenAiChatOptions.builder()
+ .withFrequencyPenalty(0.0)
+ .withMaxTokens(2048)
+ .withPresencePenalty(0.0)
+ .withStop(List.of("this-is-the-end"))
+ .withTemperature(0.7)
+ .withTopP(1.0)
+ .build();
+
+ Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
+
+ ChatResponse chatResponse = chatModel.call(prompt);
+ assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
+
+ ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
+ assertThat(responseMetadata).isNotNull();
+
+ validate(responseMetadata);
+ }
+
+ private void validate(ChatResponseMetadata responseMetadata) {
+ TestObservationRegistryAssert.assertThat(observationRegistry)
+ .doesNotHaveAnyRemainingCurrentObservation()
+ .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
+ .that()
+ .hasLowCardinalityKeyValue(
+ ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
+ AiOperationType.CHAT.value())
+ .hasLowCardinalityKeyValue(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(),
+ AiProvider.AZURE_OPENAI.value())
+ .hasLowCardinalityKeyValue(
+ ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
+ responseMetadata.getModel())
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(),
+ "0.0")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(),
+ "0.0")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
+ "[\"this-is-the-end\"]")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString(),
+ KeyValue.NONE_VALUE)
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(),
+ responseMetadata.getId())
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(),
+ "[\"stop\"]")
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getPromptTokens()))
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+ .hasHighCardinalityKeyValue(
+ ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
+ String.valueOf(responseMetadata.getUsage().getTotalTokens()))
+ .hasBeenStarted()
+ .hasBeenStopped();
+ }
+
+ @SpringBootConfiguration
+ public static class TestConfiguration {
+
+ @Bean
+ public TestObservationRegistry observationRegistry() {
+ return TestObservationRegistry.create();
+ }
+
+ @Bean
+ public OpenAIClientBuilder openAIClient() {
+ return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
+ .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
+ .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
+ .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
+ }
+
+ @Bean
+ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
+ TestObservationRegistry observationRegistry) {
+ return new AzureOpenAiChatModel(openAIClientBuilder,
+ AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build(), null,
+ List.of(), observationRegistry);
+ }
+
+ }
+
+}
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
index cb9a76aea12..4d0a6acbe71 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java
@@ -35,7 +35,8 @@ public enum AiProvider {
OPENAI("openai"),
SPRING_AI("spring_ai"),
VERTEX_AI("vertex_ai"),
- OCI_GENAI("oci_genai");
+ OCI_GENAI("oci_genai"),
+ AZURE_OPENAI("azure-openai");
private final String value;