Skip to content

Commit 7da0f0a

Browse files
committed
GH-1284: Custom Header Support in Azure OpenAI
Resolves #1284 Support adding custom headers in Azure OpenAI via configuration properties.
1 parent 035036c commit 7da0f0a

File tree

3 files changed

+63
-9
lines changed

3 files changed

+63
-9
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
package org.springframework.ai.autoconfigure.azure.openai;
1717

1818
import java.util.List;
19+
import java.util.Map;
20+
import java.util.stream.Collectors;
1921

2022
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
2123
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
@@ -40,6 +42,7 @@
4042
import com.azure.core.credential.KeyCredential;
4143
import com.azure.core.credential.TokenCredential;
4244
import com.azure.core.util.ClientOptions;
45+
import com.azure.core.util.Header;
4346

4447
/**
4548
* @author Piotr Olaszewski
@@ -57,14 +60,19 @@ public class AzureOpenAiAutoConfiguration {
5760
@Bean
5861
@ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class })
5962
public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) {
60-
6163
if (StringUtils.hasText(connectionProperties.getApiKey())) {
6264

6365
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
6466

67+
Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
68+
List<Header> headers = customHeaders.entrySet()
69+
.stream()
70+
.map(entry -> new Header(entry.getKey(), entry.getValue()))
71+
.collect(Collectors.toList());
72+
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
6573
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
6674
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
67-
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID))
75+
.clientOptions(clientOptions)
6876
.buildClient();
6977
}
7078

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,8 +13,12 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.autoconfigure.azure.openai;
1718

19+
import java.util.HashMap;
20+
import java.util.Map;
21+
1822
import org.springframework.boot.context.properties.ConfigurationProperties;
1923

2024
@ConfigurationProperties(AzureOpenAiConnectionProperties.CONFIG_PREFIX)
@@ -40,6 +44,8 @@ public class AzureOpenAiConnectionProperties {
4044
*/
4145
private String endpoint;
4246

47+
private Map<String, String> customHeaders = new HashMap<>();
48+
4349
public String getEndpoint() {
4450
return this.endpoint;
4551
}
@@ -64,4 +70,12 @@ public void setOpenAiApiKey(String openAiApiKey) {
6470
this.openAiApiKey = openAiApiKey;
6571
}
6672

73+
public Map<String, String> getCustomHeaders() {
74+
return customHeaders;
75+
}
76+
77+
public void setCustomHeaders(Map<String, String> customHeaders) {
78+
this.customHeaders = customHeaders;
79+
}
80+
6781
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,17 @@
1515
*/
1616
package org.springframework.ai.autoconfigure.azure;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.lang.reflect.Field;
21+
import java.net.URI;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.stream.Collectors;
25+
1826
import org.junit.jupiter.api.Test;
1927
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
28+
2029
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
2130
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
2231
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
@@ -33,14 +42,18 @@
3342
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
3443
import org.springframework.core.io.ClassPathResource;
3544
import org.springframework.core.io.Resource;
45+
import org.springframework.util.ReflectionUtils;
46+
47+
import com.azure.ai.openai.OpenAIClient;
48+
import com.azure.ai.openai.implementation.OpenAIClientImpl;
49+
import com.azure.core.http.HttpHeader;
50+
import com.azure.core.http.HttpHeaderName;
51+
import com.azure.core.http.HttpMethod;
52+
import com.azure.core.http.HttpPipeline;
53+
import com.azure.core.http.HttpRequest;
54+
import com.azure.core.http.HttpResponse;
3655
import reactor.core.publisher.Flux;
3756

38-
import java.util.List;
39-
import java.util.Map;
40-
import java.util.stream.Collectors;
41-
42-
import static org.assertj.core.api.Assertions.assertThat;
43-
4457
/**
4558
* @author Christian Tzolov
4659
* @author Piotr Olaszewski
@@ -87,6 +100,25 @@ public void chatCompletion() {
87100
});
88101
}
89102

103+
@Test
104+
void httpRequestContainsUserAgentHeader() {
105+
contextRunner.run(context -> {
106+
OpenAIClient openAIClient = context.getBean(OpenAIClient.class);
107+
Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient");
108+
assertThat(serviceClientField).isNotNull();
109+
ReflectionUtils.makeAccessible(serviceClientField);
110+
OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient);
111+
assertThat(oaci).isNotNull();
112+
HttpPipeline httpPipeline = oaci.getHttpPipeline();
113+
HttpResponse httpResponse = httpPipeline
114+
.send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL()))
115+
.block();
116+
assertThat(httpResponse).isNotNull();
117+
HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT);
118+
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue();
119+
});
120+
}
121+
90122
@Test
91123
public void chatCompletionStreaming() {
92124
contextRunner.run(context -> {

0 commit comments

Comments
 (0)