diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index a29ebcd8fcd..b987f7587bd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -31,8 +31,10 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; import org.springframework.core.ParameterizedTypeReference; @@ -61,6 +63,10 @@ */ public class OpenAiApi { + public static Builder builder() { + return new Builder(); + } + public static final OpenAiApi.ChatModel DEFAULT_CHAT_MODEL = ChatModel.GPT_4_O; public static final String DEFAULT_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_ADA_002.getValue(); @@ -80,7 +86,9 @@ public class OpenAiApi { /** * Create a new chat completion api with base URL set to https://api.openai.com * @param apiKey OpenAI apiKey. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String apiKey) { this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey); } @@ -89,7 +97,9 @@ public OpenAiApi(String apiKey) { * Create a new chat completion api. * @param baseUrl api base URL. * @param apiKey OpenAI apiKey. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String baseUrl, String apiKey) { this(baseUrl, apiKey, RestClient.builder(), WebClient.builder()); } @@ -100,7 +110,9 @@ public OpenAiApi(String baseUrl, String apiKey) { * @param apiKey OpenAI apiKey. * @param restClientBuilder RestClient builder. * @param webClientBuilder WebClient builder. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder) { this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); @@ -113,7 +125,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui * @param restClientBuilder RestClient builder. * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder, @@ -129,7 +143,9 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui * @param restClientBuilder RestClient builder. * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { @@ -148,10 +164,32 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e * @param restClientBuilder RestClient builder. * @param webClientBuilder WebClient builder. * @param responseErrorHandler Response error handler. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead */ + @Deprecated(since = "1.0.0.M6") public OpenAiApi(String baseUrl, String apiKey, MultiValueMap headers, String completionsPath, String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder, + webClientBuilder, responseErrorHandler); + } + + /** + * Create a new chat completion api. + * @param baseUrl api base URL. + * @param apiKey OpenAI apiKey. + * @param headers the http headers to use. + * @param completionsPath the path to the chat completions endpoint. + * @param embeddingsPath the path to the embeddings endpoint. + * @param restClientBuilder RestClient builder. + * @param webClientBuilder WebClient builder. + * @param responseErrorHandler Response error handler. + * @deprecated since 1.0.0.M6 - use {@link #builder()} instead + */ + @Deprecated(since = "1.0.0.M6") + public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, String completionsPath, + String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, + ResponseErrorHandler responseErrorHandler) { Assert.hasText(completionsPath, "Completions Path must not be null"); Assert.hasText(embeddingsPath, "Embeddings Path must not be null"); @@ -161,7 +199,7 @@ public OpenAiApi(String baseUrl, String apiKey, MultiValueMap he this.embeddingsPath = embeddingsPath; // @formatter:off Consumer finalHeaders = h -> { - h.setBearerAuth(apiKey); + h.setBearerAuth(apiKey.getValue()); h.setContentType(MediaType.APPLICATION_JSON); h.addAll(headers); }; @@ -1507,4 +1545,78 @@ public record EmbeddingList(// @formatter:off @JsonProperty("usage") Usage usage) { // @formatter:on } + public static class Builder { + + private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL; + + private ApiKey apiKey; + + private MultiValueMap headers = new LinkedMultiValueMap<>(); + + private String completionsPath = "/v1/chat/completions"; + + private String embeddingsPath = "/v1/embeddings"; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private WebClient.Builder webClientBuilder = WebClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(ApiKey apiKey) { + Assert.notNull(apiKey, "apiKey cannot be null"); + this.apiKey = apiKey; + return this; + } + + public Builder headers(MultiValueMap headers) { + Assert.notNull(headers, "headers cannot be null"); + this.headers = headers; + return this; + } + + public Builder completionsPath(String completionsPath) { + Assert.hasText(completionsPath, "completionsPath cannot be null or empty"); + this.completionsPath = completionsPath; + return this; + } + + public Builder embeddingsPath(String embeddingsPath) { + Assert.hasText(embeddingsPath, "embeddingsPath cannot be null or empty"); + this.embeddingsPath = embeddingsPath; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "webClientBuilder cannot be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null"); + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public OpenAiApi build() { + Assert.notNull(this.apiKey, "apiKey must be set"); + return new OpenAiApi(this.baseUrl, this.apiKey, this.headers, this.completionsPath, this.embeddingsPath, + this.restClientBuilder, this.webClientBuilder, this.responseErrorHandler); + } + + } + } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java new file mode 100644 index 00000000000..a2d6fe32e92 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.api; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +public class OpenAiApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + private static final String TEST_COMPLETIONS_PATH = "/test/completions"; + + private static final String TEST_EMBEDDINGS_PATH = "/test/embeddings"; + + @Test + void testMinimalBuilder() { + OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + WebClient.Builder webClientBuilder = WebClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiApi api = OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .baseUrl(TEST_BASE_URL) + .headers(headers) + .completionsPath(TEST_COMPLETIONS_PATH) + .embeddingsPath(TEST_EMBEDDINGS_PATH) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testDefaultValues() { + OpenAiApi api = OpenAiApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + // We can't directly test the default values as they're private fields, + // but we know the builder succeeded with defaults + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> { + OpenAiApi.builder().build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> { + OpenAiApi.builder().baseUrl("").build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> { + OpenAiApi.builder().baseUrl(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> { + OpenAiApi.builder().headers(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidCompletionsPath() { + assertThatThrownBy(() -> { + OpenAiApi.builder().completionsPath("").build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty"); + + assertThatThrownBy(() -> { + OpenAiApi.builder().completionsPath(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("completionsPath cannot be null or empty"); + } + + @Test + void testInvalidEmbeddingsPath() { + assertThatThrownBy(() -> { + OpenAiApi.builder().embeddingsPath("").build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty"); + + assertThatThrownBy(() -> { + OpenAiApi.builder().embeddingsPath(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("embeddingsPath cannot be null or empty"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> { + OpenAiApi.builder().restClientBuilder(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidWebClientBuilder() { + assertThatThrownBy(() -> { + OpenAiApi.builder().webClientBuilder(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("webClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> { + OpenAiApi.builder().responseErrorHandler(null).build(); + }).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("responseErrorHandler cannot be null"); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/ApiKey.java b/spring-ai-core/src/main/java/org/springframework/ai/model/ApiKey.java new file mode 100644 index 00000000000..3d052202ab6 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/ApiKey.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model; + +/** + * Some model providers API leverage short-lived api keys which must be renewed at regular + * intervals using another credential. For example, a GCP service account can be exchanged + * for an api key to call Vertex AI. + * + * Model clients use the ApiKey interface to get an api key before they make any request + * to the model provider. Implementations of this interface can cache the api key and + * perform a key refresh when it is required. + * + * @author Adib Saikali + */ +public interface ApiKey { + + /** + * Returns an api key to use for a making request. Users of this method should NOT + * cache the returned api key, instead call this method whenever you need an api key. + * Implementors of this method MUST ensure that the returned key is not expired. + * @return the current value of the api key + */ + String getValue(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/SimpleApiKey.java b/spring-ai-core/src/main/java/org/springframework/ai/model/SimpleApiKey.java new file mode 100644 index 00000000000..4887ff12b83 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/SimpleApiKey.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model; + +import org.springframework.util.Assert; + +/** + * A simple implementation of {@link ApiKey} that holds an immutable API key value. This + * implementation is suitable for cases where the API key is static and does not need to + * be refreshed or rotated. + * + * @author Adib Saikali + * @since 1.0.0 + */ +public final class SimpleApiKey implements ApiKey { + + private final String value; + + /** + * Create a new SimpleApiKey. + * @param value the API key value, must not be null or empty + * @throws IllegalArgumentException if value is null or empty + */ + public SimpleApiKey(String value) { + Assert.hasText(value, "API key value must not be null or empty"); + this.value = value; + } + + @Override + public String getValue() { + return this.value; + } + + @Override + public String toString() { + return "SimpleApiKey{value='***'}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SimpleApiKey that)) { + return false; + } + return this.value.equals(that.value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc index 2b5d1db2468..0942a22e225 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc @@ -550,3 +550,46 @@ Follow the https://github.com/spring-projects/spring-ai/blob/main/models/spring- * The link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java[OpenAiApiToolFunctionCallIT.java] tests show how to use the low-level API to call tool functions. Based on the link:https://platform.openai.com/docs/guides/function-calling/parallel-function-calling[OpenAI Function Calling] tutorial. + +== API Key Management + +Spring AI provides flexible API key management through the `ApiKey` interface and its implementations. The default implementation, `SimpleApiKey`, is suitable for most use cases, but you can also create custom implementations for more complex scenarios. + +=== Default Configuration + +By default, Spring Boot auto-configuration will create an API key bean using the `spring.ai.openai.api-key` property: + +[source,properties] +---- +spring.ai.openai.api-key=your-api-key-here +---- + +=== Custom API Key Configuration + +You can create a custom instance of `OpenAiApi` with your own `ApiKey` implementation using the builder pattern: + +[source,java] +---- +ApiKey customApiKey = new ApiKey() { + @Override + public String getValue() { + // Custom logic to retrieve API key + return "your-api-key-here"; + } +}; + +OpenAiApi openAiApi = OpenAiApi.builder() + .apiKey(customApiKey) + .build(); + +// Create a chat client with the custom OpenAiApi instance +OpenAiChatClient chatClient = new OpenAiChatClient(openAiApi); + +---- + +This is useful when you need to: + +* Retrieve the API key from a secure key store +* Rotate API keys dynamically +* Implement custom API key selection logic + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 7ce9ca291e2..35e1fcf9ecc 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -27,6 +27,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.image.observation.ImageModelObservationConvention; +import org.springframework.ai.model.SimpleApiKey; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; @@ -160,9 +161,16 @@ private OpenAiApi openAiApi(OpenAiChatProperties chatProperties, OpenAiConnectio ResolvedConnectionProperties resolved = resolveConnectionProperties(commonProperties, chatProperties, modelType); - return new OpenAiApi(resolved.baseUrl(), resolved.apiKey(), resolved.headers(), - chatProperties.getCompletionsPath(), OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH, - restClientBuilder, webClientBuilder, responseErrorHandler); + return OpenAiApi.builder() + .baseUrl(resolved.baseUrl()) + .apiKey(new SimpleApiKey(resolved.apiKey())) + .headers(resolved.headers()) + .completionsPath(chatProperties.getCompletionsPath()) + .embeddingsPath(OpenAiEmbeddingProperties.DEFAULT_EMBEDDINGS_PATH) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); } private OpenAiApi openAiApi(OpenAiEmbeddingProperties embeddingProperties, @@ -172,9 +180,16 @@ private OpenAiApi openAiApi(OpenAiEmbeddingProperties embeddingProperties, ResolvedConnectionProperties resolved = resolveConnectionProperties(commonProperties, embeddingProperties, modelType); - return new OpenAiApi(resolved.baseUrl(), resolved.apiKey(), resolved.headers(), - OpenAiChatProperties.DEFAULT_COMPLETIONS_PATH, embeddingProperties.getEmbeddingsPath(), - restClientBuilder, webClientBuilder, responseErrorHandler); + return OpenAiApi.builder() + .baseUrl(resolved.baseUrl()) + .apiKey(new SimpleApiKey(resolved.apiKey())) + .headers(resolved.headers()) + .completionsPath(OpenAiChatProperties.DEFAULT_COMPLETIONS_PATH) + .embeddingsPath(embeddingProperties.getEmbeddingsPath()) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(responseErrorHandler) + .build(); } @Bean