diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 407efb82ef2..6f43eceae8c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.azure.openai; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -40,6 +42,7 @@ import com.azure.core.credential.KeyCredential; import com.azure.core.credential.TokenCredential; import com.azure.core.util.ClientOptions; +import com.azure.core.util.Header; /** * @author Piotr Olaszewski @@ -57,14 +60,19 @@ public class AzureOpenAiAutoConfiguration { @Bean @ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class }) public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { - if (StringUtils.hasText(connectionProperties.getApiKey())) { Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); + Map customHeaders = connectionProperties.getCustomHeaders(); + List
headers = customHeaders.entrySet() + .stream() + .map(entry -> new Header(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) - .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) + .clientOptions(clientOptions) .buildClient(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java index cabd6b2e751..16a128260ea 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; +import java.util.HashMap; +import java.util.Map; + import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(AzureOpenAiConnectionProperties.CONFIG_PREFIX) @@ -40,6 +44,8 @@ public class AzureOpenAiConnectionProperties { */ private String endpoint; + private Map customHeaders = new HashMap<>(); + public String getEndpoint() { return this.endpoint; } @@ -64,4 +70,12 @@ public void setOpenAiApiKey(String openAiApiKey) { this.openAiApiKey = openAiApiKey; } + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 857d85f520b..4478e23685c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -15,8 +15,17 @@ */ package org.springframework.ai.autoconfigure.azure; +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.reflect.Field; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -33,14 +42,18 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; +import org.springframework.util.ReflectionUtils; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.implementation.OpenAIClientImpl; +import com.azure.core.http.HttpHeader; +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipeline; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; import reactor.core.publisher.Flux; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.assertj.core.api.Assertions.assertThat; - /** * @author Christian Tzolov * @author Piotr Olaszewski @@ -87,6 +100,25 @@ public void chatCompletion() { }); } + @Test + void httpRequestContainsUserAgentHeader() { + contextRunner.run(context -> { + OpenAIClient openAIClient = context.getBean(OpenAIClient.class); + Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); + assertThat(serviceClientField).isNotNull(); + ReflectionUtils.makeAccessible(serviceClientField); + OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); + assertThat(oaci).isNotNull(); + HttpPipeline httpPipeline = oaci.getHttpPipeline(); + HttpResponse httpResponse = httpPipeline + .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) + .block(); + assertThat(httpResponse).isNotNull(); + HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); + assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); + }); + } + @Test public void chatCompletionStreaming() { contextRunner.run(context -> {