Skip to content

Commit dc3279b

Browse files
committed
added feature to customize the OpenAIClientBuilder whilst retaining the default auto-configuration
Signed-off-by: Manuel Andreo Garcia <[email protected]>
1 parent 1c41c6a commit dc3279b

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,18 @@ public class AzureOpenAiAutoConfiguration {
6464

6565
private static final String APPLICATION_ID = "spring-ai";
6666

67+
@Bean
68+
@ConditionalOnMissingBean
69+
public OpenAIClientBuilderCustomizer openAIClientBuilderCustomizer() {
70+
return clientBuilder -> {
71+
};
72+
}
73+
6774
@Bean
6875
@ConditionalOnMissingBean // ({ OpenAIClient.class, TokenCredential.class })
69-
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties) {
76+
public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties,
77+
ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {
78+
7079
if (StringUtils.hasText(connectionProperties.getApiKey())) {
7180

7281
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
@@ -77,17 +86,21 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
7786
.map(entry -> new Header(entry.getKey(), entry.getValue()))
7887
.collect(Collectors.toList());
7988
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
80-
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
89+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
8190
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
8291
.clientOptions(clientOptions);
92+
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
93+
return clientBuilder;
8394
}
8495

8596
// Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is
8697
// used as OpenAI model name.
8798
if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) {
88-
return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
99+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1")
89100
.credential(new KeyCredential(connectionProperties.getOpenAiApiKey()))
90101
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
102+
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
103+
return clientBuilder;
91104
}
92105

93106
throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty");
@@ -97,14 +110,16 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c
97110
@ConditionalOnMissingBean
98111
@ConditionalOnBean(TokenCredential.class)
99112
public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties,
100-
TokenCredential tokenCredential) {
113+
TokenCredential tokenCredential, ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {
101114

102115
Assert.notNull(tokenCredential, "TokenCredential must not be null");
103116
Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");
104117

105-
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
118+
OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
106119
.credential(tokenCredential)
107120
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID));
121+
applyOpenAIClientBuilderCustomizers(clientBuilder, customizers);
122+
return clientBuilder;
108123
}
109124

110125
@Bean
@@ -169,4 +184,9 @@ public AzureOpenAiAudioTranscriptionModel azureOpenAiAudioTranscriptionModel(Ope
169184
return new AzureOpenAiAudioTranscriptionModel(openAIClient.buildClient(), audioProperties.getOptions());
170185
}
171186

187+
private void applyOpenAIClientBuilderCustomizers(OpenAIClientBuilder clientBuilder,
188+
ObjectProvider<OpenAIClientBuilderCustomizer> customizers) {
189+
customizers.orderedStream().forEach(customizer -> customizer.customize(clientBuilder));
190+
}
191+
172192
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package org.springframework.ai.autoconfigure.azure.openai;
2+
3+
import com.azure.ai.openai.OpenAIClientBuilder;
4+
5+
/**
6+
* Callback interface that can be implemented by beans wishing to customize the
7+
* {@link OpenAIClientBuilder} whilst retaining the default auto-configuration.
8+
*/
9+
@FunctionalInterface
10+
public interface OpenAIClientBuilderCustomizer {
11+
12+
/**
13+
* Customize the {@link OpenAIClientBuilder}.
14+
* @param clientBuilder the {@link OpenAIClientBuilder} to customize
15+
*/
16+
void customize(OpenAIClientBuilder clientBuilder);
17+
18+
}

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.net.URI;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.concurrent.atomic.AtomicBoolean;
2324
import java.util.stream.Collectors;
2425

2526
import com.azure.ai.openai.OpenAIClient;
@@ -33,6 +34,7 @@
3334
import com.azure.core.http.HttpResponse;
3435
import org.junit.jupiter.api.Test;
3536
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
37+
import org.springframework.ai.autoconfigure.azure.openai.OpenAIClientBuilderCustomizer;
3638
import reactor.core.publisher.Flux;
3739

3840
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
@@ -228,4 +230,20 @@ void audioTranscriptionActivation() {
228230
.run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty());
229231
}
230232

233+
@Test
234+
void openAIClientBuilderCustomizer() {
235+
AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false);
236+
AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false);
237+
this.contextRunner
238+
.withBean("first", OpenAIClientBuilderCustomizer.class,
239+
() -> clientBuilder -> firstCustomizationApplied.set(true))
240+
.withBean("second", OpenAIClientBuilderCustomizer.class,
241+
() -> clientBuilder -> secondCustomizationApplied.set(true))
242+
.run(context -> {
243+
context.getBean(OpenAIClientBuilder.class);
244+
assertThat(firstCustomizationApplied.get()).isTrue();
245+
assertThat(secondCustomizationApplied.get()).isTrue();
246+
});
247+
}
248+
231249
}

0 commit comments

Comments
 (0)