Skip to content

Commit 67b27bf

Browse files
committed
Added Entra ID identity management for Azure OpenAI, clean autoconfiguration, and updated docs to reflect changes.
1 parent 9e71b16 commit 67b27bf

File tree

5 files changed

+328
-39
lines changed

5 files changed

+328
-39
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/pom.xml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@
103103
<artifactId>mockito-core</artifactId>
104104
<scope>test</scope>
105105
</dependency>
106-
</dependencies>
106+
<dependency>
107+
<groupId>com.azure</groupId>
108+
<artifactId>azure-identity</artifactId>
109+
<version>1.15.4</version>
110+
<scope>compile</scope>
111+
</dependency>
112+
</dependencies>
107113

108114
</project>

auto-configurations/models/spring-ai-autoconfigure-model-azure-openai/src/main/java/org/springframework/ai/model/azure/openai/autoconfigure/AzureOpenAiClientBuilderConfiguration.java

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import com.azure.core.util.ClientOptions;
2828
import com.azure.core.util.Header;
2929

30+
import com.azure.identity.DefaultAzureCredential;
31+
import com.azure.identity.DefaultAzureCredentialBuilder;
3032
import org.springframework.beans.factory.ObjectProvider;
3133
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
3234
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -55,48 +57,39 @@ public class AzureOpenAiClientBuilderConfiguration {
5557
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties,
5658
ObjectProvider<AzureOpenAIClientBuilderCustomizer> customizers) {
5759

58-
if (StringUtils.hasText(connectionProperties.getApiKey())) {
59-
60-
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
61-
62-
Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
63-
List<Header> headers = customHeaders.entrySet()
64-
.stream()
65-
.map(entry -> new Header(entry.getKey(), entry.getValue()))
66-
.collect(Collectors.toList());
67-
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
68-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
69-
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
70-
.clientOptions(clientOptions);
71-
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
72-
return clientBuilder;
73-
}
60+
final OpenAIClientBuilder clientBuilder;
7461

7562
// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
7663
// used as OpenAI model name.
7764
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
78-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
65+
clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
7966
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
8067
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
8168
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
8269
return clientBuilder;
8370
}
8471

85-
throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty");
86-
}
72+
Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
73+
List<Header> headers = customHeaders.entrySet()
74+
.stream()
75+
.map(entry -> new Header(entry.getKey(), entry.getValue()))
76+
.collect(Collectors.toList());
77+
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
8778

88-
@Bean
89-
@ConditionalOnMissingBean
90-
@ConditionalOnBean(TokenCredential.class)
91-
public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties,
92-
TokenCredential tokenCredential, ObjectProvider<AzureOpenAIClientBuilderCustomizer> customizers) {
93-
94-
Assert.notNull(tokenCredential, "TokenCredential must not be null");
9579
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
9680

97-
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
98-
.credential(tokenCredential)
99-
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
81+
if (!StringUtils.hasText(connectionProperties.getApiKey())) {
82+
// Entra ID configuration, as the API key is not set
83+
clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
84+
.credential(new DefaultAzureCredentialBuilder().build())
85+
.clientOptions(clientOptions);
86+
}
87+
else {
88+
// Azure OpenAI configuration using API key and endpoint
89+
clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
90+
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
91+
.clientOptions(clientOptions);
92+
}
10093
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
10194
return clientBuilder;
10295
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.azure.openai.autoconfigure;
18+
19+
import java.lang.reflect.Field;
20+
import java.net.URI;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.stream.Collectors;
25+
26+
import com.azure.ai.openai.OpenAIClient;
27+
import com.azure.ai.openai.OpenAIClientBuilder;
28+
import com.azure.ai.openai.implementation.OpenAIClientImpl;
29+
import com.azure.core.http.HttpHeader;
30+
import com.azure.core.http.HttpHeaderName;
31+
import com.azure.core.http.HttpMethod;
32+
import com.azure.core.http.HttpPipeline;
33+
import com.azure.core.http.HttpRequest;
34+
import com.azure.core.http.HttpResponse;
35+
import org.junit.jupiter.api.Test;
36+
import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable;
37+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
38+
import reactor.core.publisher.Flux;
39+
40+
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
41+
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
42+
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
43+
import org.springframework.ai.chat.messages.AssistantMessage;
44+
import org.springframework.ai.chat.messages.Message;
45+
import org.springframework.ai.chat.messages.UserMessage;
46+
import org.springframework.ai.chat.model.ChatResponse;
47+
import org.springframework.ai.chat.model.Generation;
48+
import org.springframework.ai.chat.prompt.Prompt;
49+
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
50+
import org.springframework.ai.embedding.EmbeddingResponse;
51+
import org.springframework.boot.autoconfigure.AutoConfigurations;
52+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
53+
import org.springframework.core.io.ClassPathResource;
54+
import org.springframework.core.io.Resource;
55+
import org.springframework.util.ReflectionUtils;
56+
57+
import static org.assertj.core.api.Assertions.assertThat;
58+
59+
/**
60+
* @author Christian Tzolov
61+
* @author Piotr Olaszewski
62+
* @author Soby Chacko
63+
* @author Manuel Andreo Garcia
64+
* @since 0.8.0
65+
*/
66+
@DisabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
67+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
68+
class AzureOpenAiAutoConfigurationEntraIT {
69+
70+
private static String CHAT_MODEL_NAME = "gpt-4o";
71+
72+
private static String EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
73+
74+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues(
75+
// @formatter:off
76+
"spring.ai.azure.openai.endpoint=" + System.getenv("AZURE_OPENAI_ENDPOINT"),
77+
78+
"spring.ai.azure.openai.chat.options.deployment-name=" + CHAT_MODEL_NAME,
79+
"spring.ai.azure.openai.chat.options.temperature=0.8",
80+
"spring.ai.azure.openai.chat.options.maxTokens=123",
81+
82+
"spring.ai.azure.openai.embedding.options.deployment-name=" + EMBEDDING_MODEL_NAME,
83+
"spring.ai.azure.openai.audio.transcription.options.deployment-name=" + System.getenv("AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME")
84+
// @formatter:on
85+
);
86+
87+
private final Message systemMessage = new SystemPromptTemplate("""
88+
You are a helpful AI assistant. Your name is {name}.
89+
You are an AI assistant that helps people find information.
90+
Your name is {name}
91+
You should reply to the user's request with your name and also in the style of a {voice}.
92+
""").createMessage(Map.of("name", "Bob", "voice", "pirate"));
93+
94+
private final UserMessage userMessage = new UserMessage(
95+
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
96+
97+
@Test
98+
void chatCompletion() {
99+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
100+
.run(context -> {
101+
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
102+
ChatResponse response = chatModel.call(new Prompt(List.of(this.userMessage, this.systemMessage)));
103+
assertThat(response.getResult().getOutput().getText()).contains("Blackbeard");
104+
});
105+
}
106+
107+
@Test
108+
void httpRequestContainsUserAgentAndCustomHeaders() {
109+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
110+
.withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar",
111+
"spring.ai.azure.openai.custom-headers.fizz=buzz")
112+
.run(context -> {
113+
OpenAIClientBuilder openAIClientBuilder = context.getBean(OpenAIClientBuilder.class);
114+
OpenAIClient openAIClient = openAIClientBuilder.buildClient();
115+
Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient");
116+
assertThat(serviceClientField).isNotNull();
117+
ReflectionUtils.makeAccessible(serviceClientField);
118+
OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient);
119+
assertThat(oaci).isNotNull();
120+
HttpPipeline httpPipeline = oaci.getHttpPipeline();
121+
HttpResponse httpResponse = httpPipeline
122+
.send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL()))
123+
.block();
124+
assertThat(httpResponse).isNotNull();
125+
HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT);
126+
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue();
127+
HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo");
128+
assertThat(customHeader1.getValue()).isEqualTo("bar");
129+
HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz");
130+
assertThat(customHeader2.getValue()).isEqualTo("buzz");
131+
});
132+
}
133+
134+
@Test
135+
void chatCompletionStreaming() {
136+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
137+
.run(context -> {
138+
139+
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
140+
141+
Flux<ChatResponse> response = chatModel
142+
.stream(new Prompt(List.of(this.userMessage, this.systemMessage)));
143+
144+
List<ChatResponse> responses = response.collectList().block();
145+
assertThat(responses.size()).isGreaterThan(10);
146+
147+
String stitchedResponseContent = responses.stream()
148+
.map(ChatResponse::getResults)
149+
.flatMap(List::stream)
150+
.map(Generation::getOutput)
151+
.map(AssistantMessage::getText)
152+
.collect(Collectors.joining());
153+
154+
assertThat(stitchedResponseContent).contains("Blackbeard");
155+
});
156+
}
157+
158+
@Test
159+
void embedding() {
160+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
161+
.run(context -> {
162+
AzureOpenAiEmbeddingModel embeddingModel = context.getBean(AzureOpenAiEmbeddingModel.class);
163+
164+
EmbeddingResponse embeddingResponse = embeddingModel
165+
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
166+
assertThat(embeddingResponse.getResults()).hasSize(2);
167+
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
168+
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
169+
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
170+
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
171+
172+
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
173+
});
174+
175+
}
176+
177+
@Test
178+
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_TRANSCRIPTION_DEPLOYMENT_NAME", matches = ".+")
179+
void transcribe() {
180+
this.contextRunner
181+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
182+
.run(context -> {
183+
AzureOpenAiAudioTranscriptionModel transcriptionModel = context
184+
.getBean(AzureOpenAiAudioTranscriptionModel.class);
185+
Resource audioFile = new ClassPathResource("/speech/jfk.flac");
186+
String response = transcriptionModel.call(audioFile);
187+
assertThat(response).isEqualTo(
188+
"And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.");
189+
});
190+
}
191+
192+
@Test
193+
void chatActivation() {
194+
195+
// Disable the chat auto-configuration.
196+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
197+
.withPropertyValues("spring.ai.model.chat=none")
198+
.run(context -> {
199+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isEmpty();
200+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isEmpty();
201+
});
202+
203+
// The chat auto-configuration is enabled by default.
204+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
205+
.run(context -> {
206+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty();
207+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty();
208+
});
209+
210+
// Explicitly enable the chat auto-configuration.
211+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
212+
.withPropertyValues("spring.ai.model.chat=azure-openai")
213+
.run(context -> {
214+
assertThat(context.getBeansOfType(AzureOpenAiChatModel.class)).isNotEmpty();
215+
assertThat(context.getBeansOfType(AzureOpenAiChatProperties.class)).isNotEmpty();
216+
});
217+
}
218+
219+
@Test
220+
void embeddingActivation() {
221+
222+
// Disable the embedding auto-configuration.
223+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
224+
.withPropertyValues("spring.ai.model.embedding=none")
225+
.run(context -> {
226+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isEmpty();
227+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isEmpty();
228+
});
229+
230+
// The embedding auto-configuration is enabled by default.
231+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
232+
.run(context -> {
233+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty();
234+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty();
235+
});
236+
237+
// Explicitly enable the embedding auto-configuration.
238+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiEmbeddingAutoConfiguration.class))
239+
.withPropertyValues("spring.ai.model.embedding=azure-openai")
240+
.run(context -> {
241+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingModel.class)).isNotEmpty();
242+
assertThat(context.getBeansOfType(AzureOpenAiEmbeddingProperties.class)).isNotEmpty();
243+
});
244+
}
245+
246+
@Test
247+
void audioTranscriptionActivation() {
248+
249+
// Disable the transcription auto-configuration.
250+
this.contextRunner
251+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
252+
.withPropertyValues("spring.ai.model.audio.transcription=none")
253+
.run(context -> {
254+
assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isEmpty();
255+
assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionProperties.class)).isEmpty();
256+
});
257+
258+
// The transcription auto-configuration is enabled by default.
259+
this.contextRunner
260+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
261+
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
262+
263+
// Explicitly enable the transcription auto-configuration.
264+
this.contextRunner
265+
.withConfiguration(AutoConfigurations.of(AzureOpenAiAudioTranscriptionAutoConfiguration.class))
266+
.withPropertyValues("spring.ai.model.audio.transcription=azure-openai")
267+
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
268+
}
269+
270+
@Test
271+
void openAIClientBuilderCustomizer() {
272+
AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false);
273+
AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false);
274+
this.contextRunner.withConfiguration(AutoConfigurations.of(AzureOpenAiChatAutoConfiguration.class))
275+
.withBean("first", AzureOpenAIClientBuilderCustomizer.class,
276+
() -> clientBuilder -> firstCustomizationApplied.set(true))
277+
.withBean("second", AzureOpenAIClientBuilderCustomizer.class,
278+
() -> clientBuilder -> secondCustomizationApplied.set(true))
279+
.run(context -> {
280+
context.getBean(OpenAIClientBuilder.class);
281+
assertThat(firstCustomizationApplied.get()).isTrue();
282+
assertThat(secondCustomizationApplied.get()).isTrue();
283+
});
284+
}
285+
286+
}

0 commit comments

Comments
 (0)