Skip to content

Commit 1568861

Browse files
committed
Introduce an ApiKey Abstraction to enable Dynamic ApiKeys
Some model providers use short-lived api keys which must be renewed at regular intervals using another credential. For example, a GCP service account can be exchanged for an api key to call Vertex AI. A StaticApiKey implementation is provided for API providers that use long lived api keys. OpenAiApi class was changed so that it accepts an ApiKey instead of a String which caused a lot small changes in downstream integration tests and auto configurations that were creating a OpenAiApi using a String key. Model provider specific implementations of the ApiKey interface will be in a future commit.
1 parent c24ef4e commit 1568861

File tree

42 files changed

+192
-61
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+192
-61
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
import com.fasterxml.jackson.annotation.JsonInclude;
2929
import com.fasterxml.jackson.annotation.JsonInclude.Include;
3030
import com.fasterxml.jackson.annotation.JsonProperty;
31+
import org.springframework.ai.chat.metadata.Usage;
32+
import org.springframework.ai.embedding.Embedding;
33+
import org.springframework.ai.model.security.ApiKey;
3134
import reactor.core.publisher.Flux;
3235
import reactor.core.publisher.Mono;
3336

@@ -75,13 +78,15 @@ public class OpenAiApi {
7578

7679
private final WebClient webClient;
7780

81+
private final ApiKey apiKey;
82+
7883
private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
7984

8085
/**
8186
* Create a new chat completion api with base URL set to https://api.openai.com
8287
* @param apiKey OpenAI apiKey.
8388
*/
84-
public OpenAiApi(String apiKey) {
89+
public OpenAiApi(ApiKey apiKey) {
8590
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
8691
}
8792

@@ -90,7 +95,7 @@ public OpenAiApi(String apiKey) {
9095
* @param baseUrl api base URL.
9196
* @param apiKey OpenAI apiKey.
9297
*/
93-
public OpenAiApi(String baseUrl, String apiKey) {
98+
public OpenAiApi(String baseUrl, ApiKey apiKey) {
9499
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
95100
}
96101

@@ -101,7 +106,7 @@ public OpenAiApi(String baseUrl, String apiKey) {
101106
* @param restClientBuilder RestClient builder.
102107
* @param webClientBuilder WebClient builder.
103108
*/
104-
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
109+
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
105110
WebClient.Builder webClientBuilder) {
106111
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
107112
}
@@ -114,7 +119,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
114119
* @param webClientBuilder WebClient builder.
115120
* @param responseErrorHandler Response error handler.
116121
*/
117-
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
122+
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
118123
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
119124
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
120125
responseErrorHandler);
@@ -130,7 +135,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
130135
* @param webClientBuilder WebClient builder.
131136
* @param responseErrorHandler Response error handler.
132137
*/
133-
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
138+
public OpenAiApi(String baseUrl, ApiKey apiKey, String completionsPath, String embeddingsPath,
134139
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
135140
ResponseErrorHandler responseErrorHandler) {
136141

@@ -149,19 +154,19 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e
149154
* @param webClientBuilder WebClient builder.
150155
* @param responseErrorHandler Response error handler.
151156
*/
152-
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
157+
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
153158
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
154159
ResponseErrorHandler responseErrorHandler) {
155160

156161
Assert.hasText(completionsPath, "Completions Path must not be null");
157162
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
158163
Assert.notNull(headers, "Headers must not be null");
159164

165+
this.apiKey = apiKey;
160166
this.completionsPath = completionsPath;
161167
this.embeddingsPath = embeddingsPath;
162168
// @formatter:off
163169
Consumer<HttpHeaders> finalHeaders = h -> {
164-
h.setBearerAuth(apiKey);
165170
h.setContentType(MediaType.APPLICATION_JSON);
166171
h.addAll(headers);
167172
};
@@ -208,12 +213,12 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
208213
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
209214
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
210215

211-
return this.restClient.post()
212-
.uri(this.completionsPath)
213-
.headers(headers -> headers.addAll(additionalHttpHeader))
214-
.body(chatRequest)
215-
.retrieve()
216-
.toEntity(ChatCompletion.class);
216+
return this.restClient.post().uri(this.completionsPath).headers(headers -> {
217+
headers.addAll(additionalHttpHeader);
218+
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
219+
headers.setBearerAuth(apiKey.getValue());
220+
}
221+
}).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
217222
}
218223

219224
/**
@@ -242,9 +247,12 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
242247

243248
AtomicBoolean isInsideTool = new AtomicBoolean(false);
244249

245-
return this.webClient.post()
246-
.uri(this.completionsPath)
247-
.headers(headers -> headers.addAll(additionalHttpHeader))
250+
return this.webClient.post().uri(this.completionsPath).headers(headers -> {
251+
headers.addAll(additionalHttpHeader);
252+
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
253+
headers.setBearerAuth(apiKey.getValue());
254+
}
255+
})
248256
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
249257
.retrieve()
250258
.bodyToFlux(String.class)
@@ -318,6 +326,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
318326

319327
return this.restClient.post()
320328
.uri(this.embeddingsPath)
329+
.headers(headers -> headers.setBearerAuth(apiKey.getValue()))
321330
.body(embeddingRequest)
322331
.retrieve()
323332
.toEntity(new ParameterizedTypeReference<>() {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import org.springframework.ai.chat.prompt.Prompt;
2424
import org.springframework.ai.model.function.FunctionCallbackWrapper;
25+
import org.springframework.ai.model.security.StaticApiKey;
2526
import org.springframework.ai.openai.api.OpenAiApi;
2627
import org.springframework.ai.openai.api.tool.MockWeatherService;
2728

@@ -35,7 +36,7 @@ public class ChatCompletionRequestTests {
3536
@Test
3637
public void createRequestWithChatOptions() {
3738

38-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
39+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
3940
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build());
4041

4142
var request = client.createRequest(new Prompt("Test message content"), false);
@@ -61,7 +62,7 @@ public void promptOptionsTools() {
6162

6263
final String TOOL_FUNCTION_NAME = "CurrentWeather";
6364

64-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
65+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
6566
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build());
6667

6768
var request = client.createRequest(new Prompt("Test message content",
@@ -91,7 +92,7 @@ public void defaultOptionsTools() {
9192

9293
final String TOOL_FUNCTION_NAME = "CurrentWeather";
9394

94-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
95+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
9596
OpenAiChatOptions.builder()
9697
.withModel("DEFAULT_MODEL")
9798
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.openai;
1818

19+
import org.springframework.ai.model.security.StaticApiKey;
1920
import org.springframework.ai.openai.api.OpenAiApi;
2021
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
2122
import org.springframework.ai.openai.api.OpenAiAudioApi;
@@ -30,7 +31,7 @@ public class OpenAiTestConfiguration {
3031

3132
@Bean
3233
public OpenAiApi openAiApi() {
33-
return new OpenAiApi(getApiKey());
34+
return new OpenAiApi(new StaticApiKey(getApiKey()));
3435
}
3536

3637
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.junit.jupiter.api.Test;
2222
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import org.springframework.ai.model.security.StaticApiKey;
2324
import reactor.core.publisher.Flux;
2425

2526
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -39,7 +40,7 @@
3940
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
4041
public class OpenAiApiIT {
4142

42-
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
43+
OpenAiApi openAiApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
4344

4445
@Test
4546
void chatCompletionEntity() {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.slf4j.LoggerFactory;
2828

2929
import org.springframework.ai.model.ModelOptionsUtils;
30+
import org.springframework.ai.model.security.StaticApiKey;
3031
import org.springframework.ai.openai.api.OpenAiApi;
3132
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
3233
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
@@ -51,7 +52,7 @@ public class OpenAiApiToolFunctionCallIT {
5152

5253
MockWeatherService weatherService = new MockWeatherService();
5354

54-
OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
55+
OpenAiApi completionApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
5556

5657
private static <T> T fromJson(String json, Class<T> targetClass) {
5758
try {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.chat.model.ChatResponse;
2525
import org.springframework.ai.chat.prompt.Prompt;
26+
import org.springframework.ai.model.security.StaticApiKey;
2627
import org.springframework.ai.openai.OpenAiChatModel;
2728
import org.springframework.ai.openai.OpenAiChatOptions;
2829
import org.springframework.ai.openai.api.OpenAiApi;
@@ -67,7 +68,7 @@ static class Config {
6768

6869
@Bean
6970
public OpenAiApi chatCompletionApi() {
70-
return new OpenAiApi("Invalid API Key");
71+
return new OpenAiApi(new StaticApiKey("Invalid API Key"));
7172
}
7273

7374
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2727
import org.slf4j.Logger;
2828
import org.slf4j.LoggerFactory;
29+
import org.springframework.ai.model.security.StaticApiKey;
2930
import reactor.core.publisher.Flux;
3031

3132
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -197,7 +198,7 @@ static class Config {
197198

198199
@Bean
199200
public OpenAiApi chatCompletionApi() {
200-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
201+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
201202
}
202203

203204
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.jupiter.api.BeforeEach;
2525
import org.junit.jupiter.api.Test;
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import org.springframework.ai.model.security.StaticApiKey;
2728
import reactor.core.publisher.Flux;
2829

2930
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -169,7 +170,7 @@ public TestObservationRegistry observationRegistry() {
169170

170171
@Bean
171172
public OpenAiApi openAiApi() {
172-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
173+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
173174
}
174175

175176
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3333
import org.slf4j.Logger;
3434
import org.slf4j.LoggerFactory;
35+
import org.springframework.ai.model.security.StaticApiKey;
3536
import reactor.core.publisher.Flux;
3637

3738
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -355,7 +356,7 @@ static class Config {
355356

356357
@Bean
357358
public OpenAiApi chatCompletionApi() {
358-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
359+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
359360
}
360361

361362
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.chat.model.ChatResponse;
3232
import org.springframework.ai.chat.prompt.Prompt;
3333
import org.springframework.ai.converter.BeanOutputConverter;
34+
import org.springframework.ai.model.security.StaticApiKey;
3435
import org.springframework.ai.openai.OpenAiChatModel;
3536
import org.springframework.ai.openai.OpenAiChatOptions;
3637
import org.springframework.ai.openai.api.OpenAiApi;
@@ -234,7 +235,7 @@ static class Config {
234235

235236
@Bean
236237
public OpenAiApi chatCompletionApi() {
237-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
238+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
238239
}
239240

240241
@Bean

0 commit comments

Comments
 (0)