Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions models/spring-ai-azure-openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.*;
import com.azure.core.util.BinaryData;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
Expand All @@ -36,27 +41,51 @@
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
import com.azure.ai.openai.models.ChatCompletionsToolCall;
import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
import com.azure.ai.openai.models.ChatMessageContentItem;
import com.azure.ai.openai.models.ChatMessageImageContentItem;
import com.azure.ai.openai.models.ChatMessageImageUrl;
import com.azure.ai.openai.models.ChatMessageTextContentItem;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import io.micrometer.observation.ObservationRegistry;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
Expand All @@ -81,6 +110,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha

private static final Double DEFAULT_TEMPERATURE = 0.7;

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

/**
* The {@link OpenAIClient} used to interact with the Azure OpenAI service.
*/
Expand All @@ -96,8 +127,18 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
*/
private final AzureOpenAiChatOptions defaultOptions;

public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) {
this(microsoftOpenAiClient,
/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
this(openAIClientBuilder,
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
Expand All @@ -115,12 +156,19 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi

/**
 * Creates an {@code AzureOpenAiChatModel} that records observations to the no-op
 * {@link ObservationRegistry}.
 * @param openAIClientBuilder builder used to create the sync and async Azure OpenAI clients
 * @param options default chat options applied to every request
 * @param functionCallbackContext context used to resolve function callbacks by name
 * @param toolFunctionCallbacks tool/function callbacks to register with the model
 */
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
		FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
	// Forward the caller-supplied callbacks; passing List.of() here silently dropped them.
	this(openAIClientBuilder, options, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
}

/**
 * Creates an {@code AzureOpenAiChatModel} with an explicit {@link ObservationRegistry}
 * used to instrument chat operations.
 * @param openAIClientBuilder builder used to create the sync and async Azure OpenAI clients
 * @param options default chat options applied to every request; must not be null
 * @param functionCallbackContext context used to resolve function callbacks by name
 * @param toolFunctionCallbacks tool/function callbacks to register with the model
 * @param observationRegistry registry receiving the chat-model observations; must not be null
 */
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
		FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
		ObservationRegistry observationRegistry) {
	super(functionCallbackContext, options, toolFunctionCallbacks);
	// Message corrected: the argument is the builder, not the client itself.
	Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClientBuilder must not be null");
	Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
	// Guard the new dependency the same way as the existing ones.
	Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
	// Both a blocking and a reactive client are built from the same configured builder.
	this.openAIClient = openAIClientBuilder.buildClient();
	this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient();
	this.defaultOptions = options;
	this.observationRegistry = observationRegistry;
}

public AzureOpenAiChatOptions getDefaultOptions() {
Expand All @@ -130,22 +178,34 @@ public AzureOpenAiChatOptions getDefaultOptions() {
@Override
public ChatResponse call(Prompt prompt) {

ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(false);
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(false);

ChatResponse chatResponse = toChatResponse(chatCompletions);
ChatCompletions chatCompletions = this.openAIClient.getChatCompletions(options.getModel(), options);
ChatResponse chatResponse = toChatResponse(chatCompletions);
observationContext.setResponse(chatResponse);
return chatResponse;
});

if (!isProxyToolCalls(prompt, this.defaultOptions)
&& isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
&& isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

return chatResponse;
return response;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.azure.openai;

import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.http.policy.HttpLogOptions;
import io.micrometer.common.KeyValue;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

/**
 * Integration test verifying that {@link AzureOpenAiChatModel#call} emits a Micrometer
 * observation with the expected low- and high-cardinality key values. Runs only when the
 * {@code AZURE_OPENAI_API_KEY} and {@code AZURE_OPENAI_ENDPOINT} environment variables are
 * set, since it performs a live call against the Azure OpenAI service.
 *
 * @author Soby Chacko
 */
@SpringBootTest(classes = AzureOpenAiChatModelObservationIT.TestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
class AzureOpenAiChatModelObservationIT {

	// Model under test, wired with the TestObservationRegistry below.
	@Autowired
	private AzureOpenAiChatModel chatModel;

	// Micrometer TCK registry that records observations for assertion.
	@Autowired
	TestObservationRegistry observationRegistry;

	@Test
	void observationForImperativeChatOperation() {

		// Explicit values for every option so the high-cardinality keys asserted in
		// validate() have known expected string representations.
		var options = AzureOpenAiChatOptions.builder()
			.withFrequencyPenalty(0.0)
			.withMaxTokens(2048)
			.withPresencePenalty(0.0)
			.withStop(List.of("this-is-the-end"))
			.withTemperature(0.7)
			.withTopP(1.0)
			.build();

		Prompt prompt = new Prompt("Why does a raven look like a desk?", options);

		ChatResponse chatResponse = chatModel.call(prompt);
		assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();

		ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
		assertThat(responseMetadata).isNotNull();

		validate(responseMetadata);
	}

	/**
	 * Asserts that exactly one completed observation was recorded with the documented
	 * chat-model key values, cross-checking model, id, and token usage against the
	 * response metadata returned by the live call.
	 */
	private void validate(ChatResponseMetadata responseMetadata) {
		TestObservationRegistryAssert.assertThat(observationRegistry)
			.doesNotHaveAnyRemainingCurrentObservation()
			.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
			.that()
			.hasLowCardinalityKeyValue(
					ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
					AiOperationType.CHAT.value())
			.hasLowCardinalityKeyValue(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(),
					AiProvider.AZURE_OPENAI.value())
			.hasLowCardinalityKeyValue(
					ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
					responseMetadata.getModel())
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(),
					"0.0")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(),
					"0.0")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
					"[\"this-is-the-end\"]")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
			// top_k is not an Azure OpenAI option, so the convention reports NONE_VALUE.
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString(),
					KeyValue.NONE_VALUE)
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(),
					responseMetadata.getId())
			// NOTE(review): assumes the model finishes with "stop" rather than hitting the
			// 2048-token limit ("length") — flaky if the completion is ever truncated.
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(),
					"[\"stop\"]")
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
					String.valueOf(responseMetadata.getUsage().getPromptTokens()))
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
					String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
			.hasHighCardinalityKeyValue(
					ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
					String.valueOf(responseMetadata.getUsage().getTotalTokens()))
			.hasBeenStarted()
			.hasBeenStopped();
	}

	/**
	 * Minimal Spring context: a TCK observation registry, an Azure OpenAI client builder
	 * configured from environment variables, and the chat model wired to both.
	 */
	@SpringBootConfiguration
	public static class TestConfiguration {

		@Bean
		public TestObservationRegistry observationRegistry() {
			return TestObservationRegistry.create();
		}

		// Builder (not a built client) is exposed so the model can create both the
		// sync and async clients from the same configuration.
		@Bean
		public OpenAIClientBuilder openAIClient() {
			return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
				.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
				.serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
				.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
		}

		@Bean
		public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
				TestObservationRegistry observationRegistry) {
			return new AzureOpenAiChatModel(openAIClientBuilder,
					AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build(), null,
					List.of(), observationRegistry);
		}

	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public enum AiProvider {
OPENAI("openai"),
SPRING_AI("spring_ai"),
VERTEX_AI("vertex_ai"),
OCI_GENAI("oci_genai");
OCI_GENAI("oci_genai"),
AZURE_OPENAI("azure-openai");

private final String value;

Expand Down