Skip to content

Commit e532290

Browse files
committed
Introduce an ApiKey Abstraction to enable Dynamic ApiKeys
Some model providers use short-lived api keys which must be renewed at regular intervals using another credential. For example, a GCP service account can be exchanged for an api key to call Vertex AI. A StaticApiKey implementation is provided for API providers that use long lived api keys. OpenAiApi class was changed so that it accepts an ApiKey instead of a String which caused a lot small changes in downstream integration tests and auto configurations that were creating a OpenAiApi using a String key. Model provider specific implementations of the ApiKey interface will be in a future commit.
1 parent 6d2b22f commit e532290

File tree

42 files changed

+192
-61
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+192
-61
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
import com.fasterxml.jackson.annotation.JsonInclude;
2626
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2727
import com.fasterxml.jackson.annotation.JsonProperty;
28+
import org.springframework.ai.chat.metadata.Usage;
29+
import org.springframework.ai.embedding.Embedding;
30+
import org.springframework.ai.model.security.ApiKey;
2831
import reactor.core.publisher.Flux;
2932
import reactor.core.publisher.Mono;
3033

@@ -74,13 +77,15 @@ public class OpenAiApi {
7477

7578
private final WebClient webClient;
7679

80+
private final ApiKey apiKey;
81+
7782
private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
7883

7984
/**
8085
* Create a new chat completion api with base URL set to https://api.openai.com
8186
* @param apiKey OpenAI apiKey.
8287
*/
83-
public OpenAiApi(String apiKey) {
88+
public OpenAiApi(ApiKey apiKey) {
8489
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
8590
}
8691

@@ -89,7 +94,7 @@ public OpenAiApi(String apiKey) {
8994
* @param baseUrl api base URL.
9095
* @param apiKey OpenAI apiKey.
9196
*/
92-
public OpenAiApi(String baseUrl, String apiKey) {
97+
public OpenAiApi(String baseUrl, ApiKey apiKey) {
9398
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
9499
}
95100

@@ -100,7 +105,7 @@ public OpenAiApi(String baseUrl, String apiKey) {
100105
* @param restClientBuilder RestClient builder.
101106
* @param webClientBuilder WebClient builder.
102107
*/
103-
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
108+
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
104109
WebClient.Builder webClientBuilder) {
105110
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
106111
}
@@ -113,7 +118,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
113118
* @param webClientBuilder WebClient builder.
114119
* @param responseErrorHandler Response error handler.
115120
*/
116-
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
121+
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
117122
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
118123
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
119124
responseErrorHandler);
@@ -129,7 +134,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
129134
* @param webClientBuilder WebClient builder.
130135
* @param responseErrorHandler Response error handler.
131136
*/
132-
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
137+
public OpenAiApi(String baseUrl, ApiKey apiKey, String completionsPath, String embeddingsPath,
133138
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
134139
ResponseErrorHandler responseErrorHandler) {
135140

@@ -148,19 +153,19 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e
148153
* @param webClientBuilder WebClient builder.
149154
* @param responseErrorHandler Response error handler.
150155
*/
151-
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
156+
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
152157
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
153158
ResponseErrorHandler responseErrorHandler) {
154159

155160
Assert.hasText(completionsPath, "Completions Path must not be null");
156161
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
157162
Assert.notNull(headers, "Headers must not be null");
158163

164+
this.apiKey = apiKey;
159165
this.completionsPath = completionsPath;
160166
this.embeddingsPath = embeddingsPath;
161167
// @formatter:off
162168
Consumer<HttpHeaders> finalHeaders = h -> {
163-
h.setBearerAuth(apiKey);
164169
h.setContentType(MediaType.APPLICATION_JSON);
165170
h.addAll(headers);
166171
};
@@ -207,12 +212,12 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
207212
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
208213
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
209214

210-
return this.restClient.post()
211-
.uri(this.completionsPath)
212-
.headers(headers -> headers.addAll(additionalHttpHeader))
213-
.body(chatRequest)
214-
.retrieve()
215-
.toEntity(ChatCompletion.class);
215+
return this.restClient.post().uri(this.completionsPath).headers(headers -> {
216+
headers.addAll(additionalHttpHeader);
217+
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
218+
headers.setBearerAuth(apiKey.getValue());
219+
}
220+
}).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
216221
}
217222

218223
/**
@@ -241,9 +246,12 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
241246

242247
AtomicBoolean isInsideTool = new AtomicBoolean(false);
243248

244-
return this.webClient.post()
245-
.uri(this.completionsPath)
246-
.headers(headers -> headers.addAll(additionalHttpHeader))
249+
return this.webClient.post().uri(this.completionsPath).headers(headers -> {
250+
headers.addAll(additionalHttpHeader);
251+
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
252+
headers.setBearerAuth(apiKey.getValue());
253+
}
254+
})
247255
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
248256
.retrieve()
249257
.bodyToFlux(String.class)
@@ -317,6 +325,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
317325

318326
return this.restClient.post()
319327
.uri(this.embeddingsPath)
328+
.headers(headers -> headers.setBearerAuth(apiKey.getValue()))
320329
.body(embeddingRequest)
321330
.retrieve()
322331
.toEntity(new ParameterizedTypeReference<>() {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import org.springframework.ai.chat.prompt.Prompt;
2424
import org.springframework.ai.model.function.FunctionCallbackWrapper;
25+
import org.springframework.ai.model.security.StaticApiKey;
2526
import org.springframework.ai.openai.api.OpenAiApi;
2627
import org.springframework.ai.openai.api.tool.MockWeatherService;
2728

@@ -35,7 +36,7 @@ public class ChatCompletionRequestTests {
3536
@Test
3637
public void createRequestWithChatOptions() {
3738

38-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
39+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
3940
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build());
4041

4142
var request = client.createRequest(new Prompt("Test message content"), false);
@@ -61,7 +62,7 @@ public void promptOptionsTools() {
6162

6263
final String TOOL_FUNCTION_NAME = "CurrentWeather";
6364

64-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
65+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
6566
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build());
6667

6768
var request = client.createRequest(new Prompt("Test message content",
@@ -91,7 +92,7 @@ public void defaultOptionsTools() {
9192

9293
final String TOOL_FUNCTION_NAME = "CurrentWeather";
9394

94-
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
95+
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
9596
OpenAiChatOptions.builder()
9697
.withModel("DEFAULT_MODEL")
9798
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.openai;
1818

19+
import org.springframework.ai.model.security.StaticApiKey;
1920
import org.springframework.ai.openai.api.OpenAiApi;
2021
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
2122
import org.springframework.ai.openai.api.OpenAiAudioApi;
@@ -30,7 +31,7 @@ public class OpenAiTestConfiguration {
3031

3132
@Bean
3233
public OpenAiApi openAiApi() {
33-
return new OpenAiApi(getApiKey());
34+
return new OpenAiApi(new StaticApiKey(getApiKey()));
3435
}
3536

3637
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.junit.jupiter.api.Test;
2222
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import org.springframework.ai.model.security.StaticApiKey;
2324
import reactor.core.publisher.Flux;
2425

2526
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -39,7 +40,7 @@
3940
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
4041
public class OpenAiApiIT {
4142

42-
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
43+
OpenAiApi openAiApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
4344

4445
@Test
4546
void chatCompletionEntity() {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.slf4j.LoggerFactory;
2828

2929
import org.springframework.ai.model.ModelOptionsUtils;
30+
import org.springframework.ai.model.security.StaticApiKey;
3031
import org.springframework.ai.openai.api.OpenAiApi;
3132
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
3233
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
@@ -52,7 +53,7 @@ public class OpenAiApiToolFunctionCallIT {
5253

5354
MockWeatherService weatherService = new MockWeatherService();
5455

55-
OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
56+
OpenAiApi completionApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
5657

5758
private static <T> T fromJson(String json, Class<T> targetClass) {
5859
try {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModeAdditionalHttpHeadersIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.chat.model.ChatResponse;
2525
import org.springframework.ai.chat.prompt.Prompt;
26+
import org.springframework.ai.model.security.StaticApiKey;
2627
import org.springframework.ai.openai.OpenAiChatModel;
2728
import org.springframework.ai.openai.OpenAiChatOptions;
2829
import org.springframework.ai.openai.api.OpenAiApi;
@@ -68,7 +69,7 @@ static class Config {
6869

6970
@Bean
7071
public OpenAiApi chatCompletionApi() {
71-
return new OpenAiApi("Invalid API Key");
72+
return new OpenAiApi(new StaticApiKey("Invalid API Key"));
7273
}
7374

7475
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2727
import org.slf4j.Logger;
2828
import org.slf4j.LoggerFactory;
29+
import org.springframework.ai.model.security.StaticApiKey;
2930
import reactor.core.publisher.Flux;
3031

3132
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -197,7 +198,7 @@ static class Config {
197198

198199
@Bean
199200
public OpenAiApi chatCompletionApi() {
200-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
201+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
201202
}
202203

203204
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.jupiter.api.BeforeEach;
2525
import org.junit.jupiter.api.Test;
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import org.springframework.ai.model.security.StaticApiKey;
2728
import reactor.core.publisher.Flux;
2829

2930
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -167,7 +168,7 @@ public TestObservationRegistry observationRegistry() {
167168

168169
@Bean
169170
public OpenAiApi openAiApi() {
170-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
171+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
171172
}
172173

173174
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3333
import org.slf4j.Logger;
3434
import org.slf4j.LoggerFactory;
35+
import org.springframework.ai.model.security.StaticApiKey;
3536
import reactor.core.publisher.Flux;
3637

3738
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -355,7 +356,7 @@ static class Config {
355356

356357
@Bean
357358
public OpenAiApi chatCompletionApi() {
358-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
359+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
359360
}
360361

361362
@Bean

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.ai.chat.model.ChatResponse;
3131
import org.springframework.ai.chat.prompt.Prompt;
3232
import org.springframework.ai.converter.BeanOutputConverter;
33+
import org.springframework.ai.model.security.StaticApiKey;
3334
import org.springframework.ai.openai.OpenAiChatModel;
3435
import org.springframework.ai.openai.OpenAiChatOptions;
3536
import org.springframework.ai.openai.api.OpenAiApi;
@@ -186,7 +187,7 @@ static class Config {
186187

187188
@Bean
188189
public OpenAiApi chatCompletionApi() {
189-
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
190+
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
190191
}
191192

192193
@Bean

0 commit comments

Comments
 (0)