diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml index d087952e24c..319428d487c 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml @@ -103,6 +103,12 @@ mockito-core test - + + com.azure + azure-identity + 1.15.4 + compile + + diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java index 739b0a8e7da..784c4bf4433 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java @@ -27,6 +27,8 @@ import com.azure.core.util.ClientOptions; import com.azure.core.util.Header; +import com.azure.identity.DefaultAzureCredential; +import com.azure.identity.DefaultAzureCredentialBuilder; import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -55,48 +57,39 @@ public class AzureOpenAiClientBuilderConfiguration { public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties, ObjectProvider customizers) { - if (StringUtils.hasText(connectionProperties.getApiKey())) { - - 
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); - - Map customHeaders = connectionProperties.getCustomHeaders(); - List
headers = customHeaders.entrySet() - .stream() - .map(entry -> new Header(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()); - ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); - OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) - .credential(new AzureKeyCredential(connectionProperties.getApiKey())) - .clientOptions(clientOptions); - applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); - return clientBuilder; - } + final OpenAIClientBuilder clientBuilder; // Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is // used as OpenAI model name. if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) { - OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") + clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") .credential(new KeyCredential(connectionProperties.getOpenAiApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); return clientBuilder; } - throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty"); - } + Map customHeaders = connectionProperties.getCustomHeaders(); + List
headers = customHeaders.entrySet() + .stream() + .map(entry -> new Header(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); - @Bean - @ConditionalOnMissingBean - @ConditionalOnBean(TokenCredential.class) - public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, - TokenCredential tokenCredential, ObjectProvider customizers) { - - Assert.notNull(tokenCredential, "TokenCredential must not be null"); Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); - OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) - .credential(tokenCredential) - .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); + if (!StringUtils.hasText(connectionProperties.getApiKey())) { + // Entra ID configuration, as the API key is not set + clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) + .credential(new DefaultAzureCredentialBuilder().build()) + .clientOptions(clientOptions); + } + else { + // Azure OpenAI configuration using API key and endpoint + clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) + .credential(new AzureKeyCredential(connectionProperties.getApiKey())) + .clientOptions(clientOptions); + } applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); return clientBuilder; } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiAutoConfigurationEntraIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiAutoConfigurationEntraIT.java new file mode 100644 index 00000000000..1f6c5fec54c --- /dev/null +++ 
b/auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/test/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiAutoConfigurationEntraIT.java @@ -0,0 +1,286 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.azure.openai.autoconfigure; + +import java.lang.reflect.Field; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.implementation.OpenAIClientImpl; +import com.azure.core.http.HttpHeader; +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipeline; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; +import org.springframework.ai.azure.openai.AzureOpenAiChatModel; +import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import 
org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + * @author Piotr Olaszewski + * @author Soby Chacko + * @author Manuel Andreo Garcia + * @since 0.8.0 + */ +@DisabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") +class AzureOpenAiAutoConfigurationEntraIT { + + private static String CHAT_MODEL_NAME = "gpt-4o"; + + private static String EMBEDDING_MODEL_NAME = "text-embedding-ada-002"; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.azure.openai.endpoint=" + System.getenv("AZURE_OPENAI_ENDPOINT"), + + "spring.ai.azure.openai.chat.options.deployment-name=" + CHAT_MODEL_NAME, + "spring.ai.azure.openai.chat.options.temperature=0.8", + "spring.ai.azure.openai.chat.options.maxTokens=123", + + "spring.ai.azure.openai.embedding.options.deployment-name=" + EMBEDDING_MODEL_NAME, + "spring.ai.azure.openai.audio.transcription.options.deployment-name=" + System.getenv("AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME") + // @formatter:on + ); + + private final Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. 
+ You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. + """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + private final UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + @Test + void chatCompletion() { + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .run(context -> { + AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); + ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage))); + assertThat(response.getResult().getOutput().getText()).contains("Blackbeard"); + }); + } + + @Test + void httpRequestContainsUserAgentAndCustomHeaders() { + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", + "spring.ai.azure.openai.custom-headers.fizz=buzz") + .run(context -> { + OpenAIClientBuilder openAIClientBuilder = context.getBean(OpenAIClientBuilder.class); + OpenAIClient openAIClient = openAIClientBuilder.buildClient(); + Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); + assertThat(serviceClientField).isNotNull(); + ReflectionUtils.makeAccessible(serviceClientField); + OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); + assertThat(oaci).isNotNull(); + HttpPipeline httpPipeline = oaci.getHttpPipeline(); + HttpResponse httpResponse = httpPipeline + .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) + .block(); + assertThat(httpResponse).isNotNull(); + HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); + 
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); + HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo"); + assertThat(customHeader1.getValue()).isEqualTo("bar"); + HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz"); + assertThat(customHeader2.getValue()).isEqualTo("buzz"); + }); + } + + @Test + void chatCompletionStreaming() { + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .run(context -> { + + AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class); + + Flux response = chatModel + .stream(new Prompt(List.of(this.userMessage, this.systemMessage))); + + List responses = response.collectList().block(); + assertThat(responses.size()).isGreaterThan(10); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + }); + } + + @Test + void embedding() { + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class)) + .run(context -> { + AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class); + + EmbeddingResponse embeddingResponse = embeddingModel + .embedForResponse(List.of("Hello World", "World is big and salvation is near")); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + + assertThat(embeddingModel.dimensions()).isEqualTo(1536); + }); + + } + + @Test + 
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME", matches = ".+") + void transcribe() { + this.contextRunner + .withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class)) + .run(context -> { + AzureOpenAiAudioTranscriptionModel transcriptionModel = context + .getBean(AzureOpenAiAudioTranscriptionModel.class); + Resource audioFile = new ClassPathResource("/speech/jfk.flac"); + String response = transcriptionModel.call(audioFile); + assertThat(response).isEqualTo( + "And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."); + }); + } + + @Test + void chatActivation() { + + // Disable the chat auto-configuration. + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.chat=none") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty(); + }); + + // The chat auto-configuration is enabled by default. + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty(); + }); + + // Explicitly enable the chat auto-configuration. + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.chat=azure-openai") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty(); + }); + } + + @Test + void embeddingActivation() { + + // Disable the embedding auto-configuration. 
+ this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.embedding=none") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isEmpty(); + }); + + // The embedding auto-configuration is enabled by default. + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty(); + }); + + // Explicitly enable the embedding auto-configuration. + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.embedding=azure-openai") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty(); + }); + } + + @Test + void audioTranscriptionActivation() { + + // Disable the transcription auto-configuration. + this.contextRunner + .withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.audio.transcription=none") + .run(context -> { + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty(); + assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionProperties.class)).isEmpty(); + }); + + // The transcription auto-configuration is enabled by default. 
+ this.contextRunner + .withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class)) + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty()); + + // Explicitly enable the transcription auto-configuration. + this.contextRunner + .withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class)) + .withPropertyValues("spring.ai.model.audio.transcription=azure-openai") + .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty()); + } + + @Test + void openAIClientBuilderCustomizer() { + AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false); + AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false); + this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class)) + .withBean("first", AzureOpenAIClientBuilderCustomizer.class, + () -> clientBuilder -> firstCustomizationApplied.set(true)) + .withBean("second", AzureOpenAIClientBuilderCustomizer.class, + () -> clientBuilder -> secondCustomizationApplied.set(true)) + .run(context -> { + context.getBean(OpenAIClientBuilder.class); + assertThat(firstCustomizationApplied.get()).isTrue(); + assertThat(secondCustomizationApplied.get()).isTrue(); + }); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 9705f31f46b..e2cf866cedf 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -10,12 +10,11 @@ The Azure OpenAI client offers three options to connect: using an Azure API key === Azure API Key & Endpoint -Obtain your Azure OpenAI `endpoint` and `api-key` from the Azure OpenAI Service section on the https://portal.azure.com[Azure Portal]. 
+To access models using an API key, obtain your Azure OpenAI `endpoint` and `api-key` from the Azure OpenAI Service section on the https://portal.azure.com[Azure Portal]. Spring AI defines two configuration properties: - -1. `spring.ai.azure.openai.api-key`: Set this to the value of the `API Key` obtained from Azure. +1. `spring.ai.azure.openai.api-key`: Set this to the value of the `API Key` obtained from Azure. 2. `spring.ai.azure.openai.endpoint`: Set this to the endpoint URL obtained when provisioning your model in Azure. You can set these configuration properties by exporting environment variables: @@ -39,9 +38,12 @@ export SPRING_AI_AZURE_OPENAI_OPENAI_API_KEY= === Microsoft Entra ID -To authenticate using Microsoft Entra ID (formerly Azure Active Directory), create a `TokenCredential` bean in your configuration. -If this bean is available, an `OpenAIClient` instance will be created using the token credentials. -bd +For keyless authentication using Microsoft Entra ID (formerly Azure Active Directory), set _only_ the `spring.ai.azure.openai.endpoint` configuration property and _not_ the api-key property mentioned above. + +Finding only the endpoint property, your application will evaluate several different options for retrieving credentials and an `OpenAIClient` instance will be created using the token credentials. + +NOTE: It is no longer necessary to create a `TokenCredential` bean; it is configured for you automatically. + === Deployment Name To use Azure AI applications, you need to create an Azure AI Deployment through the link:https://oai.azure.com/portal[Azure AI Portal]. @@ -67,7 +69,7 @@ This is because in OpenAI there is no `Deployment Name`, only a `Model Name`. NOTE: The property `spring.ai.azure.openai.chat.options.model` has been renamed to `spring.ai.azure.openai.chat.options.deployment-name`. 
NOTE: If you decide to connect to `OpenAI` instead of `Azure OpenAI`, by setting the `spring.ai.azure.openai.openai-api-key=` property, -then the `spring.ai.azure.openai.chat.options.deployment-name` is treathed as an link:https://platform.openai.com/docs/models[OpenAI model] name. +then the `spring.ai.azure.openai.chat.options.deployment-name` is treated as an link:https://platform.openai.com/docs/models[OpenAI model] name. ==== Access the OpenAI Model diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc index 4c8340f26ca..01c5dfce9f6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/azure-openai-embeddings.adoc @@ -46,9 +46,11 @@ export SPRING_AI_AZURE_OPENAI_OPENAI_API_KEY= === Microsoft Entra ID -To authenticate using Microsoft Entra ID (formerly Azure Active Directory), create a `TokenCredential` bean in your configuration. -If this bean is available, an `OpenAIClient` instance will be created using the token credentials. +For keyless authentication using Microsoft Entra ID (formerly Azure Active Directory), set _only_ the `spring.ai.azure.openai.endpoint` configuration property and _not_ the api-key property mentioned above. +Finding only the endpoint property, your application will evaluate several different options for retrieving credentials and an `OpenAIClient` instance will be created using the token credentials. + +NOTE: It is no longer necessary to create a `TokenCredential` bean; it is configured for you automatically. === Add Repositories and BOM