diff --git a/models/spring-ai-azure-openai/pom.xml b/models/spring-ai-azure-openai/pom.xml index d70c098a074..63443562899 100644 --- a/models/spring-ai-azure-openai/pom.xml +++ b/models/spring-ai-azure-openai/pom.xml @@ -58,6 +58,12 @@ spring-boot-starter-test test + + + io.micrometer + micrometer-observation-test + test + diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 62c4f2198c2..c61b6e26f30 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -16,11 +16,16 @@ package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIAsyncClient; -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.OpenAIClientBuilder; -import com.azure.ai.openai.models.*; -import com.azure.core.util.BinaryData; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -36,27 +41,51 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.ChatChoice; +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; +import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; +import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatCompletionsResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsToolCall; +import com.azure.ai.openai.models.ChatCompletionsToolDefinition; +import com.azure.ai.openai.models.ChatMessageContentItem; +import com.azure.ai.openai.models.ChatMessageImageContentItem; +import com.azure.ai.openai.models.ChatMessageImageUrl; +import com.azure.ai.openai.models.ChatMessageTextContentItem; +import com.azure.ai.openai.models.ChatRequestAssistantMessage; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestSystemMessage; +import com.azure.ai.openai.models.ChatRequestToolMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; +import com.azure.ai.openai.models.CompletionsFinishReason; +import com.azure.ai.openai.models.ContentFilterResultsForPrompt; +import com.azure.ai.openai.models.FunctionCall; +import com.azure.ai.openai.models.FunctionDefinition; +import com.azure.core.util.BinaryData; +import io.micrometer.observation.ObservationRegistry; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; -import java.util.Base64; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; - /** * {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by * {@link OpenAIClient}. @@ -81,6 +110,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha private static final Double DEFAULT_TEMPERATURE = 0.7; + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + /** * The {@link OpenAIClient} used to interact with the Azure OpenAI service. */ @@ -96,8 +127,18 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha */ private final AzureOpenAiChatOptions defaultOptions; - public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) { - this(microsoftOpenAiClient, + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { + this(openAIClientBuilder, AzureOpenAiChatOptions.builder() .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) .withTemperature(DEFAULT_TEMPERATURE) @@ -115,12 +156,19 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) { + this(openAIClientBuilder, options, functionCallbackContext, List.of(), ObservationRegistry.NOOP); + } + + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, + FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks, + ObservationRegistry observationRegistry) { super(functionCallbackContext, options, toolFunctionCallbacks); Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); this.openAIClient = openAIClientBuilder.buildClient(); this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient(); this.defaultOptions = options; + this.observationRegistry = observationRegistry; } public AzureOpenAiChatOptions getDefaultOptions() { @@ -130,22 +178,34 @@ public AzureOpenAiChatOptions getDefaultOptions() { @Override public ChatResponse call(Prompt prompt) { - ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); - options.setStream(false); + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(AiProvider.AZURE_OPENAI.value()) + .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) + .build(); - ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); + ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); + options.setStream(false); - ChatResponse chatResponse = toChatResponse(chatCompletions); + ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options); + ChatResponse chatResponse = toChatResponse(chatCompletions); + observationContext.setResponse(chatResponse); + return chatResponse; + }); if (!isProxyToolCalls(prompt, this.defaultOptions) - && isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { - var toolCallConversation = handleToolCalls(prompt, chatResponse); + && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) { + var toolCallConversation = handleToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. return this.call(new Prompt(toolCallConversation, prompt.getOptions())); } - return chatResponse; + return response; } @Override diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java new file mode 100644 index 00000000000..df4907a09ee --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java @@ -0,0 +1,159 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.azure.openai; + +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; +import io.micrometer.common.KeyValue; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * @author Soby Chacko + */ +@SpringBootTest(classes = AzureOpenAiChatModelObservationIT.TestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") +class AzureOpenAiChatModelObservationIT { + + @Autowired + private AzureOpenAiChatModel chatModel; + + @Autowired + TestObservationRegistry observationRegistry; + + @Test + void observationForImperativeChatOperation() { + + var options = AzureOpenAiChatOptions.builder() + .withFrequencyPenalty(0.0) + .withMaxTokens(2048) + .withPresencePenalty(0.0) + .withStop(List.of("this-is-the-end")) + .withTemperature(0.7) + .withTopP(1.0) + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata); + } + + private void validate(ChatResponseMetadata responseMetadata) { + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) + .that() + .hasLowCardinalityKeyValue( + ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.CHAT.value()) + .hasLowCardinalityKeyValue(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), + AiProvider.AZURE_OPENAI.value()) + .hasLowCardinalityKeyValue( + ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(), + responseMetadata.getModel()) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), + "0.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), + "0.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), + "[\"this-is-the-end\"]") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString(), + KeyValue.NONE_VALUE) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(), + responseMetadata.getId()) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), + "[\"stop\"]") + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getGenerationTokens())) + .hasHighCardinalityKeyValue( + ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OpenAIClientBuilder openAIClient() { + return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) + .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); + } + + @Bean + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, + TestObservationRegistry observationRegistry) { + return new AzureOpenAiChatModel(openAIClientBuilder, + AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build(), null, + List.of(), observationRegistry); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index cb9a76aea12..4d0a6acbe71 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -35,7 +35,8 @@ public enum AiProvider { OPENAI("openai"), SPRING_AI("spring_ai"), VERTEX_AI("vertex_ai"), - OCI_GENAI("oci_genai"); + OCI_GENAI("oci_genai"), + AZURE_OPENAI("azure-openai"); private final String value;