Skip to content

Commit fc2690c

Browse files
committed
Add flexible API key management for OpenAI
Introduces a new API key interface that allows users to customize how API keys are provided and managed in their Spring AI applications. This change improves security and flexibility by: - Adding core ApiKey interface and SimpleApiKey implementation - Adding builder pattern for OpenAiApi creation - Deprecating public constructors in favor of builder API (since 1.0.0.M6) - Added docs The new system enables users to implement their own key management strategies while maintaining backward compatibility with property-based configuration.
1 parent 45421b1 commit fc2690c

File tree

6 files changed

+432
-7
lines changed

6 files changed

+432
-7
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
import reactor.core.publisher.Flux;
3232
import reactor.core.publisher.Mono;
3333

34+
import org.springframework.ai.model.ApiKey;
3435
import org.springframework.ai.model.ChatModelDescription;
3536
import org.springframework.ai.model.ModelOptionsUtils;
37+
import org.springframework.ai.model.SimpleApiKey;
3638
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
3739
import org.springframework.ai.retry.RetryUtils;
3840
import org.springframework.core.ParameterizedTypeReference;
@@ -62,6 +64,10 @@
6264
*/
6365
public class OpenAiApi {
6466

67+
public static Builder builder() {
68+
return new Builder();
69+
}
70+
6571
public static final OpenAiApi.ChatModel DEFAULT_CHAT_MODEL = ChatModel.GPT_4_O;
6672

6773
public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue();
@@ -81,7 +87,9 @@ public class OpenAiApi {
8187
/**
8288
* Create a new chat completion api with base URL set to https://api.openai.com
8389
* @param apiKey OpenAI apiKey.
90+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
8491
*/
92+
@Deprecated(since = "1.0.0.M6")
8593
public OpenAiApi(String apiKey) {
8694
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
8795
}
@@ -90,7 +98,9 @@ public OpenAiApi(String apiKey) {
9098
* Create a new chat completion api.
9199
* @param baseUrl api base URL.
92100
* @param apiKey OpenAI apiKey.
101+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
93102
*/
103+
@Deprecated(since = "1.0.0.M6")
94104
public OpenAiApi(String baseUrl, String apiKey) {
95105
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
96106
}
@@ -101,7 +111,9 @@ public OpenAiApi(String baseUrl, String apiKey) {
101111
* @param apiKey OpenAI apiKey.
102112
* @param restClientBuilder RestClient builder.
103113
* @param webClientBuilder WebClient builder.
114+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
104115
*/
116+
@Deprecated(since = "1.0.0.M6")
105117
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
106118
WebClient.Builder webClientBuilder) {
107119
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
@@ -114,7 +126,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
114126
* @param restClientBuilder RestClient builder.
115127
* @param webClientBuilder WebClient builder.
116128
* @param responseErrorHandler Response error handler.
129+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
117130
*/
131+
@Deprecated(since = "1.0.0.M6")
118132
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
119133
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
120134
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
@@ -130,7 +144,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
130144
* @param restClientBuilder RestClient builder.
131145
* @param webClientBuilder WebClient builder.
132146
* @param responseErrorHandler Response error handler.
147+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
133148
*/
149+
@Deprecated(since = "1.0.0.M6")
134150
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
135151
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
136152
ResponseErrorHandler responseErrorHandler) {
@@ -149,10 +165,32 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e
149165
* @param restClientBuilder RestClient builder.
150166
* @param webClientBuilder WebClient builder.
151167
* @param responseErrorHandler Response error handler.
168+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
152169
*/
170+
@Deprecated(since = "1.0.0.M6")
153171
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
154172
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
155173
ResponseErrorHandler responseErrorHandler) {
174+
this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder,
175+
webClientBuilder, responseErrorHandler);
176+
}
177+
178+
/**
179+
* Create a new chat completion api.
180+
* @param baseUrl api base URL.
181+
* @param apiKey OpenAI apiKey.
182+
* @param headers the http headers to use.
183+
* @param completionsPath the path to the chat completions endpoint.
184+
* @param embeddingsPath the path to the embeddings endpoint.
185+
* @param restClientBuilder RestClient builder.
186+
* @param webClientBuilder WebClient builder.
187+
* @param responseErrorHandler Response error handler.
188+
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
189+
*/
190+
@Deprecated(since = "1.0.0.M6")
191+
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
192+
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
193+
ResponseErrorHandler responseErrorHandler) {
156194

157195
Assert.hasText(completionsPath, "Completions Path must not be null");
158196
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
@@ -162,7 +200,7 @@ public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> he
162200
this.embeddingsPath = embeddingsPath;
163201
// @formatter:off
164202
Consumer<HttpHeaders> finalHeaders = h -> {
165-
h.setBearerAuth(apiKey);
203+
h.setBearerAuth(apiKey.getValue());
166204
h.setContentType(MediaType.APPLICATION_JSON);
167205
h.addAll(headers);
168206
};
@@ -1607,4 +1645,78 @@ public record EmbeddingList<T>(// @formatter:off
16071645
@JsonProperty("usage") Usage usage) { // @formatter:on
16081646
}
16091647

1648+
public static class Builder {
1649+
1650+
private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;
1651+
1652+
private ApiKey apiKey;
1653+
1654+
private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
1655+
1656+
private String completionsPath = "/v1/chat/completions";
1657+
1658+
private String embeddingsPath = "/v1/embeddings";
1659+
1660+
private RestClient.Builder restClientBuilder = RestClient.builder();
1661+
1662+
private WebClient.Builder webClientBuilder = WebClient.builder();
1663+
1664+
private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER;
1665+
1666+
public Builder baseUrl(String baseUrl) {
1667+
Assert.hasText(baseUrl, "baseUrl cannot be null or empty");
1668+
this.baseUrl = baseUrl;
1669+
return this;
1670+
}
1671+
1672+
public Builder apiKey(ApiKey apiKey) {
1673+
Assert.notNull(apiKey, "apiKey cannot be null");
1674+
this.apiKey = apiKey;
1675+
return this;
1676+
}
1677+
1678+
public Builder headers(MultiValueMap<String, String> headers) {
1679+
Assert.notNull(headers, "headers cannot be null");
1680+
this.headers = headers;
1681+
return this;
1682+
}
1683+
1684+
public Builder completionsPath(String completionsPath) {
1685+
Assert.hasText(completionsPath, "completionsPath cannot be null or empty");
1686+
this.completionsPath = completionsPath;
1687+
return this;
1688+
}
1689+
1690+
public Builder embeddingsPath(String embeddingsPath) {
1691+
Assert.hasText(embeddingsPath, "embeddingsPath cannot be null or empty");
1692+
this.embeddingsPath = embeddingsPath;
1693+
return this;
1694+
}
1695+
1696+
public Builder restClientBuilder(RestClient.Builder restClientBuilder) {
1697+
Assert.notNull(restClientBuilder, "restClientBuilder cannot be null");
1698+
this.restClientBuilder = restClientBuilder;
1699+
return this;
1700+
}
1701+
1702+
public Builder webClientBuilder(WebClient.Builder webClientBuilder) {
1703+
Assert.notNull(webClientBuilder, "webClientBuilder cannot be null");
1704+
this.webClientBuilder = webClientBuilder;
1705+
return this;
1706+
}
1707+
1708+
public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) {
1709+
Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null");
1710+
this.responseErrorHandler = responseErrorHandler;
1711+
return this;
1712+
}
1713+
1714+
public OpenAiApi build() {
1715+
Assert.notNull(this.apiKey, "apiKey must be set");
1716+
return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath,
1717+
this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler);
1718+
}
1719+
1720+
}
1721+
16101722
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai.api;
17+
18+
import org.junit.jupiter.api.Test;
19+
20+
import org.springframework.ai.model.ApiKey;
21+
import org.springframework.ai.model.SimpleApiKey;
22+
import org.springframework.util.LinkedMultiValueMap;
23+
import org.springframework.util.MultiValueMap;
24+
import org.springframework.web.client.ResponseErrorHandler;
25+
import org.springframework.web.client.RestClient;
26+
import org.springframework.web.reactive.function.client.WebClient;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
30+
import static org.mockito.Mockito.mock;
31+
32+
public class OpenAiApiBuilderTests {
33+
34+
private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key");
35+
36+
private static final String TEST_BASE_URL = "https://test.openai.com";
37+
38+
private static final String TEST_COMPLETIONS_PATH = "/test/completions";
39+
40+
private static final String TEST_EMBEDDINGS_PATH = "/test/embeddings";
41+
42+
@Test
43+
void testMinimalBuilder() {
44+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build();
45+
46+
assertThat(api).isNotNull();
47+
}
48+
49+
@Test
50+
void testFullBuilder() {
51+
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
52+
headers.add("Custom-Header", "test-value");
53+
RestClient.Builder restClientBuilder = RestClient.builder();
54+
WebClient.Builder webClientBuilder = WebClient.builder();
55+
ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class);
56+
57+
OpenAiApi api = OpenAiApi.builder()
58+
.apiKey(TEST_API_KEY)
59+
.baseUrl(TEST_BASE_URL)
60+
.headers(headers)
61+
.completionsPath(TEST_COMPLETIONS_PATH)
62+
.embeddingsPath(TEST_EMBEDDINGS_PATH)
63+
.restClientBuilder(restClientBuilder)
64+
.webClientBuilder(webClientBuilder)
65+
.responseErrorHandler(errorHandler)
66+
.build();
67+
68+
assertThat(api).isNotNull();
69+
}
70+
71+
@Test
72+
void testDefaultValues() {
73+
OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build();
74+
75+
assertThat(api).isNotNull();
76+
// We can't directly test the default values as they're private fields,
77+
// but we know the builder succeeded with defaults
78+
}
79+
80+
@Test
81+
void testMissingApiKey() {
82+
assertThatThrownBy(() -> {
83+
OpenAiApi.builder().build();
84+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("apiKey must be set");
85+
}
86+
87+
@Test
88+
void testInvalidBaseUrl() {
89+
assertThatThrownBy(() -> {
90+
OpenAiApi.builder().baseUrl("").build();
91+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty");
92+
93+
assertThatThrownBy(() -> {
94+
OpenAiApi.builder().baseUrl(null).build();
95+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty");
96+
}
97+
98+
@Test
99+
void testInvalidHeaders() {
100+
assertThatThrownBy(() -> {
101+
OpenAiApi.builder().headers(null).build();
102+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("headers cannot be null");
103+
}
104+
105+
@Test
106+
void testInvalidCompletionsPath() {
107+
assertThatThrownBy(() -> {
108+
OpenAiApi.builder().completionsPath("").build();
109+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty");
110+
111+
assertThatThrownBy(() -> {
112+
OpenAiApi.builder().completionsPath(null).build();
113+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty");
114+
}
115+
116+
@Test
117+
void testInvalidEmbeddingsPath() {
118+
assertThatThrownBy(() -> {
119+
OpenAiApi.builder().embeddingsPath("").build();
120+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty");
121+
122+
assertThatThrownBy(() -> {
123+
OpenAiApi.builder().embeddingsPath(null).build();
124+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty");
125+
}
126+
127+
@Test
128+
void testInvalidRestClientBuilder() {
129+
assertThatThrownBy(() -> {
130+
OpenAiApi.builder().restClientBuilder(null).build();
131+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("restClientBuilder cannot be null");
132+
}
133+
134+
@Test
135+
void testInvalidWebClientBuilder() {
136+
assertThatThrownBy(() -> {
137+
OpenAiApi.builder().webClientBuilder(null).build();
138+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("webClientBuilder cannot be null");
139+
}
140+
141+
@Test
142+
void testInvalidResponseErrorHandler() {
143+
assertThatThrownBy(() -> {
144+
OpenAiApi.builder().responseErrorHandler(null).build();
145+
}).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("responseErrorHandler cannot be null");
146+
}
147+
148+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.model;
17+
18+
/**
19+
* Some model providers API leverage short-lived api keys which must be renewed at regular
20+
* intervals using another credential. For example, a GCP service account can be exchanged
21+
* for an api key to call Vertex AI.
22+
*
23+
* Model clients use the ApiKey interface to get an api key before they make any request
24+
* to the model provider. Implementations of this interface can cache the api key and
25+
* perform a key refresh when it is required.
26+
*
27+
* @author Adib Saikali
28+
*/
29+
public interface ApiKey {
30+
31+
/**
32+
* Returns an api key to use for a making request. Users of this method should NOT
33+
* cache the returned api key, instead call this method whenever you need an api key.
34+
* Implementors of this method MUST ensure that the returned key is not expired.
35+
* @return the current value of the api key
36+
*/
37+
String getValue();
38+
39+
}

0 commit comments

Comments
 (0)