Skip to content

Commit 8fb9d33

Browse files
author
Manuel Andreo Garcia
committed
added feature to customize the OpenAIClientBuilder whilst retaining the default auto-configuration
Signed-off-by: Manuel Andreo Garcia <[email protected]>
1 parent 1c41c6a commit 8fb9d33

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,18 @@ public class AzureOpenAiAutoConfiguration {
6464

6565
private static final String APPLICATION_ID = "spring-ai";
6666

67+
@Bean
68+
@ConditionalOnMissingBean
69+
public OpenAIClientBuilderCustomizer openAIClientBuilderCustomizer() {
70+
return clientBuilder -> {
71+
};
72+
}
73+
6774
@Bean
6875
@ConditionalOnMissingBean // ({ OpenAIClient.class, TokenCredential.class })
69-
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties) {
76+
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties,
77+
OpenAIClientBuilderCustomizer customizer) {
78+
7079
if (StringUtils.hasText(connectionProperties.getApiKey())) {
7180

7281
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
@@ -77,17 +86,21 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
7786
.map(entry -> new Header(entry.getKey(), entry.getValue()))
7887
.collect(Collectors.toList());
7988
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
80-
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
89+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
8190
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
8291
.clientOptions(clientOptions);
92+
customizer.customize(clientBuilder);
93+
return clientBuilder;
8394
}
8495

8596
// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
8697
// used as OpenAI model name.
8798
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
88-
return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
99+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
89100
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
90101
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
102+
customizer.customize(clientBuilder);
103+
return clientBuilder;
91104
}
92105

93106
throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty");
@@ -97,14 +110,16 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
97110
@ConditionalOnMissingBean
98111
@ConditionalOnBean(TokenCredential.class)
99112
public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties,
100-
TokenCredential tokenCredential) {
113+
TokenCredential tokenCredential, OpenAIClientBuilderCustomizer customizer) {
101114

102115
Assert.notNull(tokenCredential, "TokenCredential must not be null");
103116
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
104117

105-
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
118+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
106119
.credential(tokenCredential)
107120
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
121+
customizer.customize(clientBuilder);
122+
return clientBuilder;
108123
}
109124

110125
@Bean
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package org.springframework.ai.autoconfigure.azure.openai;
2+
3+
import com.azure.ai.openai.OpenAIClientBuilder;
4+
5+
/**
6+
* Callback interface that can be implemented by beans wishing to customize the
7+
* {@link OpenAIClientBuilder} whilst retaining the default auto-configuration.
8+
*/
9+
@FunctionalInterface
10+
public interface OpenAIClientBuilderCustomizer {
11+
12+
/**
13+
* Customize the {@link OpenAIClientBuilder}.
14+
* @param clientBuilder the {@link OpenAIClientBuilder} to customize
15+
*/
16+
void customize(OpenAIClientBuilder clientBuilder);
17+
18+
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.net.URI;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.concurrent.atomic.AtomicBoolean;
2324
import java.util.stream.Collectors;
2425

2526
import com.azure.ai.openai.OpenAIClient;
@@ -33,6 +34,7 @@
3334
import com.azure.core.http.HttpResponse;
3435
import org.junit.jupiter.api.Test;
3536
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
37+
import org.springframework.ai.autoconfigure.azure.openai.OpenAIClientBuilderCustomizer;
3638
import reactor.core.publisher.Flux;
3739

3840
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
@@ -228,4 +230,16 @@ void audioTranscriptionActivation() {
228230
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
229231
}
230232

233+
@Test
234+
void openAIClientBuilderCustomizer() {
235+
AtomicBoolean customized = new AtomicBoolean(false);
236+
this.contextRunner
237+
.withBean(OpenAIClientBuilderCustomizer.class,
238+
(OpenAIClientBuilderCustomizer) clientBuilder -> customized.set(true))
239+
.run(context -> {
240+
context.getBean(OpenAIClientBuilder.class);
241+
assertThat(customized.get()).isEqualTo(true);
242+
});
243+
}
244+
231245
}

0 commit comments

Comments
 (0)