From 6ba81ad6ef9f7561988b2d937d5405eeb2b6c574 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Fri, 20 Sep 2024 14:38:58 -0400 Subject: [PATCH] Adding integration test for Azure custom headers --- .../azure/AzureOpenAiAutoConfigurationIT.java | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 11cd9409fbd..c0a7245659d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure; import com.azure.ai.openai.OpenAIClient; @@ -95,22 +96,29 @@ void chatCompletion() { } @Test - void httpRequestContainsUserAgentHeader() { - contextRunner.run(context -> { - OpenAIClient openAIClient = context.getBean(OpenAIClient.class); - Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); - assertThat(serviceClientField).isNotNull(); - ReflectionUtils.makeAccessible(serviceClientField); - OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); - assertThat(oaci).isNotNull(); - HttpPipeline httpPipeline = oaci.getHttpPipeline(); - HttpResponse httpResponse = httpPipeline - .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) - .block(); - assertThat(httpResponse).isNotNull(); - HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); - assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); - }); + void httpRequestContainsUserAgentAndCustomHeaders() { + contextRunner + .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", + "spring.ai.azure.openai.custom-headers.fizz=buzz") + .run(context -> { + OpenAIClient openAIClient = context.getBean(OpenAIClient.class); + Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); + assertThat(serviceClientField).isNotNull(); + ReflectionUtils.makeAccessible(serviceClientField); + OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); + assertThat(oaci).isNotNull(); + HttpPipeline httpPipeline = oaci.getHttpPipeline(); + HttpResponse httpResponse = httpPipeline + .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) + .block(); + assertThat(httpResponse).isNotNull(); + HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); + assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); + HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo"); + assertThat(customHeader1.getValue()).isEqualTo("bar"); + HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz"); + assertThat(customHeader2.getValue()).isEqualTo("buzz"); + }); } @Test