Skip to content

Commit 245fefd

Browse files
committed
Add flexible API key management for OpenAI
Introduces a new API key management system that allows users to customize how API keys are provided and managed in their Spring AI applications. This change improves security and flexibility by: - Adding core ApiKey interface and SimpleApiKey implementation - Introducing @OPENAIAPIKEY qualifier for bean disambiguation - Supporting custom API key providers for secure key management - Adding auto-configuration support for API key injection - Adding builder pattern for OpenAiApi configuration - Deprecating public constructors in favor of builder API (since 1.0.0.M6) The new system enables users to implement their own key management strategies while maintaining backward compatibility with property-based configuration.
1 parent 4e1358a commit 245fefd

File tree

9 files changed

+643
-13
lines changed

9 files changed

+643
-13
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
import reactor.core.publisher.Flux;
3232
import reactor.core.publisher.Mono;
3333

34+
import org.springframework.ai.model.ApiKey;
3435
import org.springframework.ai.model.ChatModelDescription;
3536
import org.springframework.ai.model.ModelOptionsUtils;
37+
import org.springframework.ai.model.SimpleApiKey;
3638
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
3739
import org.springframework.ai.retry.RetryUtils;
3840
import org.springframework.core.ParameterizedTypeReference;
@@ -61,6 +63,10 @@
6163
*/
6264
public class OpenAiApi {
6365

66+
public static Builder builder() {
67+
return new Builder();
68+
}
69+
6470
public static final OpenAiApi.ChatModel DEFAULT_CHAT_MODEL = ChatModel.GPT_4_O;
6571

6672
public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue();
@@ -80,7 +86,9 @@ public class OpenAiApi {
8086
/**
8187
* Create a new chat completion api with base URL set to https://api.openai.com
8288
* @param apiKey OpenAI apiKey.
89+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
8390
*/
91+
@Deprecated(since = "1.0.0.M6")
8492
public OpenAiApi(String apiKey) {
8593
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
8694
}
@@ -89,7 +97,9 @@ public OpenAiApi(String apiKey) {
8997
* Create a new chat completion api.
9098
* @param baseUrl api base URL.
9199
* @param apiKey OpenAI apiKey.
100+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
92101
*/
102+
@Deprecated(since = "1.0.0.M6")
93103
public OpenAiApi(String baseUrl, String apiKey) {
94104
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
95105
}
@@ -100,7 +110,9 @@ public OpenAiApi(String baseUrl, String apiKey) {
100110
* @param apiKey OpenAI apiKey.
101111
* @param restClientBuilder RestClient builder.
102112
* @param webClientBuilder WebClient builder.
113+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
103114
*/
115+
@Deprecated(since = "1.0.0.M6")
104116
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
105117
WebClient.Builder webClientBuilder) {
106118
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
@@ -113,7 +125,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
113125
* @param restClientBuilder RestClient builder.
114126
* @param webClientBuilder WebClient builder.
115127
* @param responseErrorHandler Response error handler.
128+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
116129
*/
130+
@Deprecated(since = "1.0.0.M6")
117131
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
118132
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
119133
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
@@ -129,7 +143,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
129143
* @param restClientBuilder RestClient builder.
130144
* @param webClientBuilder WebClient builder.
131145
* @param responseErrorHandler Response error handler.
146+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
132147
*/
148+
@Deprecated(since = "1.0.0.M6")
133149
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
134150
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
135151
ResponseErrorHandler responseErrorHandler) {
@@ -148,10 +164,32 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e
148164
* @param restClientBuilder RestClient builder.
149165
* @param webClientBuilder WebClient builder.
150166
* @param responseErrorHandler Response error handler.
167+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
151168
*/
169+
@Deprecated(since = "1.0.0.M6")
152170
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
153171
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
154172
ResponseErrorHandler responseErrorHandler) {
173+
this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder,
174+
webClientBuilder, responseErrorHandler);
175+
}
176+
177+
/**
178+
* Create a new chat completion api.
179+
* @param baseUrl api base URL.
180+
* @param apiKey OpenAI apiKey.
181+
* @param headers the http headers to use.
182+
* @param completionsPath the path to the chat completions endpoint.
183+
* @param embeddingsPath the path to the embeddings endpoint.
184+
* @param restClientBuilder RestClient builder.
185+
* @param webClientBuilder WebClient builder.
186+
* @param responseErrorHandler Response error handler.
187+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
188+
*/
189+
@Deprecated(since = "1.0.0.M6")
190+
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
191+
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
192+
ResponseErrorHandler responseErrorHandler) {
155193

156194
Assert.hasText(completionsPath, "Completions Path must not be null");
157195
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
@@ -161,7 +199,7 @@ public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> he
161199
this.embeddingsPath = embeddingsPath;
162200
// @formatter:off
163201
Consumer<HttpHeaders> finalHeaders = h -> {
164-
h.setBearerAuth(apiKey);
202+
h.setBearerAuth(apiKey.getValue());
165203
h.setContentType(MediaType.APPLICATION_JSON);
166204
h.addAll(headers);
167205
};
@@ -1507,4 +1545,78 @@ public record EmbeddingList<T>(// @formatter:off
15071545
@JsonProperty("usage") Usage usage) { // @formatter:on
15081546
}
15091547

1548+
public static class Builder {
1549+
1550+
private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;
1551+
1552+
private ApiKey apiKey;
1553+
1554+
private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
1555+
1556+
private String completionsPath = "/v1/chat/completions";
1557+
1558+
private String embeddingsPath = "/v1/embeddings";
1559+
1560+
private RestClient.Builder restClientBuilder = RestClient.builder();
1561+
1562+
private WebClient.Builder webClientBuilder = WebClient.builder();
1563+
1564+
private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
1565+
1566+
public Builder baseUrl(String baseUrl) {
1567+
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
1568+
this.baseUrl = baseUrl;
1569+
return this;
1570+
}
1571+
1572+
public Builder apiKey(ApiKey apiKey) {
1573+
Assert.notNull(apiKey, "apiKey cannot be null");
1574+
this.apiKey = apiKey;
1575+
return this;
1576+
}
1577+
1578+
public Builder headers(MultiValueMap<String, String> headers) {
1579+
Assert.notNull(headers, "headers cannot be null");
1580+
this.headers = headers;
1581+
return this;
1582+
}
1583+
1584+
public Builder completionsPath(String completionsPath) {
1585+
Assert.hasText(completionsPath, "completionsPath cannot be null or empty");
1586+
this.completionsPath = completionsPath;
1587+
return this;
1588+
}
1589+
1590+
public Builder embeddingsPath(String embeddingsPath) {
1591+
Assert.hasText(embeddingsPath, "embeddingsPath cannot be null or empty");
1592+
this.embeddingsPath = embeddingsPath;
1593+
return this;
1594+
}
1595+
1596+
public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
1597+
Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
1598+
this.restClientBuilder = restClientBuilder;
1599+
return this;
1600+
}
1601+
1602+
public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
1603+
Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
1604+
this.webClientBuilder = webClientBuilder;
1605+
return this;
1606+
}
1607+
1608+
public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
1609+
Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
1610+
this.responseErrorHandler = responseErrorHandler;
1611+
return this;
1612+
}
1613+
1614+
public OpenAiApi build() {
1615+
Assert.notNull(this.apiKey, "apiKey must be set");
1616+
return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath,
1617+
this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
1618+
}
1619+
1620+
}
1621+
15101622
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.api;
18+
19+
import org.springframework.beans.factory.annotation.Qualifier;
20+
21+
import java.lang.annotation.*;
22+
23+
/**
24+
* Qualifier annotation for OpenAI API key beans. Used to distinguish OpenAI API keys from
25+
* other provider API keys.
26+
*
27+
* @author Mark Pollack
28+
*/
29+
@Target({ ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, ElementType.TYPE, ElementType.ANNOTATION_TYPE })
30+
@Retention(RetentionPolicy.RUNTIME)
31+
@Documented
32+
@Qualifier
33+
public @interface OpenAiApiKey {
34+
35+
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai.api;
17+
18+
import org.junit.jupiter.api.Test;
19+
20+
import org.springframework.ai.model.ApiKey;
21+
import org.springframework.ai.model.SimpleApiKey;
22+
import org.springframework.util.LinkedMultiValueMap;
23+
import org.springframework.util.MultiValueMap;
24+
import org.springframework.web.client.ResponseErrorHandler;
25+
import org.springframework.web.client.RestClient;
26+
import org.springframework.web.reactive.function.client.WebClient;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
30+
import static org.mockito.Mockito.mock;
31+
32+
public class OpenAiApiBuilderTests {
33+
34+
private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key");
35+
36+
private static final String TEST_BASE_URL = "https://test.openai.com";
37+
38+
private static final String TEST_COMPLETIONS_PATH = "/test/completions";
39+
40+
private static final String TEST_EMBEDDINGS_PATH = "/test/embeddings";
41+
42+
@Test
43+
void testMinimalBuilder() {
44+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build();
45+
46+
assertThat(api).isNotNull();
47+
}
48+
49+
@Test
50+
void testFullBuilder() {
51+
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
52+
headers.add("Custom-Header", "test-value");
53+
RestClient.Builder restClientBuilder = RestClient.builder();
54+
WebClient.Builder webClientBuilder = WebClient.builder();
55+
ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class);
56+
57+
OpenAiApi api = OpenAiApi.builder()
58+
.apiKey(TEST_API_KEY)
59+
.baseUrl(TEST_BASE_URL)
60+
.headers(headers)
61+
.completionsPath(TEST_COMPLETIONS_PATH)
62+
.embeddingsPath(TEST_EMBEDDINGS_PATH)
63+
.restClientBuilder(restClientBuilder)
64+
.webClientBuilder(webClientBuilder)
65+
.responseErrorHandler(errorHandler)
66+
.build();
67+
68+
assertThat(api).isNotNull();
69+
}
70+
71+
@Test
72+
void testDefaultValues() {
73+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build();
74+
75+
assertThat(api).isNotNull();
76+
// We can't directly test the default values as they're private fields,
77+
// but we know the builder succeeded with defaults
78+
}
79+
80+
@Test
81+
void testMissingApiKey() {
82+
assertThatThrownBy(() -> {
83+
OpenAiApi.builder().build();
84+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("apiKey must be set");
85+
}
86+
87+
@Test
88+
void testInvalidBaseUrl() {
89+
assertThatThrownBy(() -> {
90+
OpenAiApi.builder().baseUrl("").build();
91+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty");
92+
93+
assertThatThrownBy(() -> {
94+
OpenAiApi.builder().baseUrl(null).build();
95+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty");
96+
}
97+
98+
@Test
99+
void testInvalidHeaders() {
100+
assertThatThrownBy(() -> {
101+
OpenAiApi.builder().headers(null).build();
102+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("headers cannot be null");
103+
}
104+
105+
@Test
106+
void testInvalidCompletionsPath() {
107+
assertThatThrownBy(() -> {
108+
OpenAiApi.builder().completionsPath("").build();
109+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty");
110+
111+
assertThatThrownBy(() -> {
112+
OpenAiApi.builder().completionsPath(null).build();
113+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty");
114+
}
115+
116+
@Test
117+
void testInvalidEmbeddingsPath() {
118+
assertThatThrownBy(() -> {
119+
OpenAiApi.builder().embeddingsPath("").build();
120+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty");
121+
122+
assertThatThrownBy(() -> {
123+
OpenAiApi.builder().embeddingsPath(null).build();
124+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty");
125+
}
126+
127+
@Test
128+
void testInvalidRestClientBuilder() {
129+
assertThatThrownBy(() -> {
130+
OpenAiApi.builder().restClientBuilder(null).build();
131+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("restClientBuilder cannot be null");
132+
}
133+
134+
@Test
135+
void testInvalidWebClientBuilder() {
136+
assertThatThrownBy(() -> {
137+
OpenAiApi.builder().webClientBuilder(null).build();
138+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("webClientBuilder cannot be null");
139+
}
140+
141+
@Test
142+
void testInvalidResponseErrorHandler() {
143+
assertThatThrownBy(() -> {
144+
OpenAiApi.builder().responseErrorHandler(null).build();
145+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("responseErrorHandler cannot be null");
146+
}
147+
148+
}

0 commit comments

Comments
 (0)