diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index feed776cce4..d9409dbaeb3 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -50,7 +50,9 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall; @@ -108,12 +110,17 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha */ private final OpenAIClient openAIClient; + /** + * The {@link OpenAIAsyncClient} used for streaming async operations. + */ + private final OpenAIAsyncClient openAIAsyncClient; + /** * The configuration information for a chat completions request. */ private AzureOpenAiChatOptions defaultOptions; - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { + public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) { this(microsoftOpenAiClient, AzureOpenAiChatOptions.builder() .withDeploymentName(DEFAULT_DEPLOYMENT_NAME) @@ -121,21 +128,22 @@ public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) { .build()); } - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) { - this(microsoftOpenAiClient, options, null); + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) { + this(openAIClientBuilder, options, null); } - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options, + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, FunctionCallbackContext functionCallbackContext) { - this(microsoftOpenAiClient, options, functionCallbackContext, List.of()); + this(openAIClientBuilder, options, functionCallbackContext, List.of()); } - public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options, + public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks) { super(functionCallbackContext, options, toolFunctionCallbacks); - Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); + Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(options, "AzureOpenAiChatOptions must not be null"); - this.openAIClient = microsoftOpenAiClient; + this.openAIClient = openAIClientBuilder.buildClient(); + this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient(); this.defaultOptions = options; } @@ -170,11 +178,11 @@ public Flux stream(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(true); - IterableStream chatCompletionsStream = this.openAIClient + Flux chatCompletionsStream = this.openAIAsyncClient .getChatCompletionsStream(options.getModel(), options); final var isFunctionCall = new AtomicBoolean(false); - final Flux accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream) + final Flux accessibleChatCompletionsFlux = chatCompletionsStream // Note: the first chat completions can be ignored when using Azure OpenAI // service which is a known service bug. .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices())) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java index e6da1ebbf38..af7a7d0b762 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiImageModel.java @@ -44,7 +44,6 @@ public class AzureOpenAiImageModel implements ImageModel { private final Logger logger = LoggerFactory.getLogger(getClass()); - @Autowired private final OpenAIClient openAIClient; private final AzureOpenAiImageOptions defaultOptions; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index f7edea989b0..dbc6fa46d83 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -17,6 +17,7 @@ package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; @@ -44,7 +45,7 @@ public class AzureChatCompletionsOptionsTests { @Test public void createRequestWithChatOptions() { - OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class); AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito .mock(AzureChatEnhancementConfiguration.class); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientTest.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientTest.java new file mode 100644 index 00000000000..eec6e952624 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; + +/** + * @author Soby Chacko + */ +@SpringBootTest(classes = AzureOpenAiChatClientTest.TestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") +public class AzureOpenAiChatClientTest { + + @Autowired + private ChatClient chatClient; + + @Test + void streamingAndImperativeResponsesContainIdenticalRelevantResults() { + String prompt = "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two. " + + "List them with a numerical index. Do not use any abbreviations in state or capitals."; + + // Imperative call + String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content(); + String imperativeStatesData = extractStatesData(rawDataFromImperativeCall); + String formattedImperativeResponse = formatResponse(imperativeStatesData); + + // Streaming call + String stitchedResponseFromStream = chatClient.prompt(prompt) + .stream() + .content() + .collectList() + .block() + .stream() + .collect(Collectors.joining()); + String streamingStatesData = extractStatesData(stitchedResponseFromStream); + String formattedStreamingResponse = formatResponse(streamingStatesData); + + // Assertions + assertThat(formattedStreamingResponse).isEqualTo(formattedImperativeResponse); + assertThat(formattedStreamingResponse).contains("1. Alabama - Montgomery"); + assertThat(formattedStreamingResponse).contains("50. Wyoming - Cheyenne"); + assertThat(formattedStreamingResponse.lines().count()).isEqualTo(50); + } + + private String extractStatesData(String rawData) { + int firstStateIndex = rawData.indexOf("1. Alabama - Montgomery"); + String lastAlphabeticalState = "50. Wyoming - Cheyenne"; + int lastStateIndex = rawData.indexOf(lastAlphabeticalState) + lastAlphabeticalState.length(); + return rawData.substring(firstStateIndex, lastStateIndex); + } + + private String formatResponse(String response) { + return String.join("\n", Arrays.stream(response.split("\n")).map(String::strip).toArray(String[]::new)); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public OpenAIClientBuilder openAIClient() { + return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) + .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); + } + + @Bean + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { + return new AzureOpenAiChatModel(openAIClientBuilder, + AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build()); + + } + + @Bean + public ChatClient chatClient(AzureOpenAiChatModel azureOpenAiChatModel) { + return ChatClient.builder(azureOpenAiChatModel).build(); + } + + } + +} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index 24be25953f6..14c58fc3bc8 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.azure.openai; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.OpenAIServiceVersion; import com.azure.core.credential.AzureKeyCredential; @@ -262,17 +261,16 @@ record ActorsFilmsRecord(String actor, List movies) { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIClientBuilder openAIClientBuilder() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) .serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW) - .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)) - .buildClient(); + .httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS)); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) { - return new AzureOpenAiChatModel(openAIClient, + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { + return new AzureOpenAiChatModel(openAIClientBuilder, AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build()); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java index 6ae824badcf..b57b1422a79 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/MockAzureOpenAiTestConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; @@ -40,8 +41,9 @@ * {@link Dispatcher} to integrate with Spring {@link MockMvc}. * * @author John Blum + * @author Soby Chacko * @see org.springframework.boot.SpringBootConfiguration - * @see org.springframework.ai.test.config.MockAiTestConfiguration + * @see org.springframework.ai.azure.openai.MockAiTestConfiguration * @since 0.7.0 */ @SpringBootConfiguration @@ -51,15 +53,13 @@ public class MockAzureOpenAiTestConfiguration { @Bean - OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) { - + OpenAIClientBuilder microsoftAzureOpenAiClient(MockWebServer webServer) { HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH); - - return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient(); + return new OpenAIClientBuilder().endpoint(baseUrl.toString()); } @Bean - AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) { + AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder microsoftAzureOpenAiClient) { return new AzureOpenAiChatModel(microsoftAzureOpenAiClient); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 2b876455c4c..635407cd766 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -183,14 +183,13 @@ void functionCallSequentialAndStreamTest() { public static class TestConfiguration { @Bean - public OpenAIClient openAIClient() { + public OpenAIClientBuilder openAIClient() { return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) - .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) - .buildClient(); + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")); } @Bean - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) { + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClient, String selectedModel) { return new AzureOpenAiChatModel(openAIClient, AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 6f43eceae8c..64aa55c1572 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -59,7 +59,7 @@ public class AzureOpenAiAutoConfiguration { @Bean @ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class }) - public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { + public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties) { if (StringUtils.hasText(connectionProperties.getApiKey())) { Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); @@ -72,8 +72,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) - .clientOptions(clientOptions) - .buildClient(); + .clientOptions(clientOptions); } // Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is @@ -81,8 +80,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) { return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") .credential(new KeyCredential(connectionProperties.getOpenAiApiKey())) - .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); } throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty"); @@ -91,7 +89,7 @@ public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionPrope @Bean @ConditionalOnMissingBean @ConditionalOnBean(TokenCredential.class) - public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, + public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, TokenCredential tokenCredential) { Assert.notNull(tokenCredential, "TokenCredential must not be null"); @@ -99,19 +97,18 @@ public OpenAIClient openAIClientWithTokenCredential(AzureOpenAiConnectionPropert return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(tokenCredential) - .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) - .buildClient(); + .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); } @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = AzureOpenAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { - return new AzureOpenAiChatModel(openAIClient, chatProperties.getOptions(), functionCallbackContext, + return new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(), functionCallbackContext, toolFunctionCallbacks); } @@ -119,9 +116,9 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, @ConditionalOnMissingBean @ConditionalOnProperty(prefix = AzureOpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClient openAIClient, + public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClientBuilder openAIClient, AzureOpenAiEmbeddingProperties embeddingProperties) { - return new AzureOpenAiEmbeddingModel(openAIClient, embeddingProperties.getMetadataMode(), + return new AzureOpenAiEmbeddingModel(openAIClient.buildClient(), embeddingProperties.getMetadataMode(), embeddingProperties.getOptions()); } @@ -137,19 +134,19 @@ public FunctionCallbackContext springAiFunctionManager(ApplicationContext contex @ConditionalOnMissingBean @ConditionalOnProperty(prefix = AzureOpenAiImageOptionsProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIClient openAIClient, + public AzureOpenAiImageModel azureOpenAiImageClient(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiImageOptionsProperties imageProperties) { - return new AzureOpenAiImageModel(openAIClient, imageProperties.getOptions()); + return new AzureOpenAiImageModel(openAIClientBuilder.buildClient(), imageProperties.getOptions()); } @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = AzureOpenAiAudioTranscriptionProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) - public AzureOpenAiAudioTranscriptionModel azureOpenAiAudioTranscriptionModel(OpenAIClient openAIClient, + public AzureOpenAiAudioTranscriptionModel azureOpenAiAudioTranscriptionModel(OpenAIClientBuilder openAIClient, AzureOpenAiAudioTranscriptionProperties audioProperties) { - return new AzureOpenAiAudioTranscriptionModel(openAIClient, audioProperties.getOptions()); + return new AzureOpenAiAudioTranscriptionModel(openAIClient.buildClient(), audioProperties.getOptions()); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index c0a7245659d..0f4e7fe4930 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.azure; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.ai.openai.implementation.OpenAIClientImpl; import com.azure.core.http.*; import org.junit.jupiter.api.Test; @@ -101,7 +102,8 @@ void httpRequestContainsUserAgentAndCustomHeaders() { .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", "spring.ai.azure.openai.custom-headers.fizz=buzz") .run(context -> { - OpenAIClient openAIClient = context.getBean(OpenAIClient.class); + OpenAIClientBuilder openAIClientBuilder = context.getBean(OpenAIClientBuilder.class); + OpenAIClient openAIClient = openAIClientBuilder.buildClient(); Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); assertThat(serviceClientField).isNotNull(); ReflectionUtils.makeAccessible(serviceClientField);