Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.model.security.ApiKey;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -75,13 +78,15 @@ public class OpenAiApi {

private final WebClient webClient;

private final ApiKey apiKey;

private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();

/**
* Create a new chat completion api with base URL set to https://api.openai.com
* @param apiKey OpenAI apiKey.
*/
public OpenAiApi(String apiKey) {
public OpenAiApi(ApiKey apiKey) {
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
}

Expand All @@ -90,7 +95,7 @@ public OpenAiApi(String apiKey) {
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
*/
public OpenAiApi(String baseUrl, String apiKey) {
public OpenAiApi(String baseUrl, ApiKey apiKey) {
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
}

Expand All @@ -101,7 +106,7 @@ public OpenAiApi(String baseUrl, String apiKey) {
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
*/
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder) {
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
Expand All @@ -114,7 +119,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
public OpenAiApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
responseErrorHandler);
Expand All @@ -130,7 +135,7 @@ public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBui
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
public OpenAiApi(String baseUrl, ApiKey apiKey, String completionsPath, String embeddingsPath,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Expand All @@ -149,19 +154,19 @@ public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String e
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Assert.hasText(completionsPath, "Completions Path must not be null");
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
Assert.notNull(headers, "Headers must not be null");

this.apiKey = apiKey;
this.completionsPath = completionsPath;
this.embeddingsPath = embeddingsPath;
// @formatter:off
Consumer<HttpHeaders> finalHeaders = h -> {
h.setBearerAuth(apiKey);
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
};
Expand Down Expand Up @@ -208,12 +213,12 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");

return this.restClient.post()
.uri(this.completionsPath)
.headers(headers -> headers.addAll(additionalHttpHeader))
.body(chatRequest)
.retrieve()
.toEntity(ChatCompletion.class);
return this.restClient.post().uri(this.completionsPath).headers(headers -> {
headers.addAll(additionalHttpHeader);
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
headers.setBearerAuth(apiKey.getValue());
}
}).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
}

/**
Expand Down Expand Up @@ -242,9 +247,12 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat

AtomicBoolean isInsideTool = new AtomicBoolean(false);

return this.webClient.post()
.uri(this.completionsPath)
.headers(headers -> headers.addAll(additionalHttpHeader))
return this.webClient.post().uri(this.completionsPath).headers(headers -> {
headers.addAll(additionalHttpHeader);
if (!additionalHttpHeader.containsKey(HttpHeaders.AUTHORIZATION)) {
headers.setBearerAuth(apiKey.getValue());
}
})
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.retrieve()
.bodyToFlux(String.class)
Expand Down Expand Up @@ -318,6 +326,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<

return this.restClient.post()
.uri(this.embeddingsPath)
.headers(headers -> headers.setBearerAuth(apiKey.getValue()))
.body(embeddingRequest)
.retrieve()
.toEntity(new ParameterizedTypeReference<>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.tool.MockWeatherService;

Expand All @@ -35,7 +36,7 @@ public class ChatCompletionRequestTests {
@Test
public void createRequestWithChatOptions() {

var client = new OpenAiChatModel(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6).build());

var request = client.createRequest(new Prompt("Test message content"), false);
Expand All @@ -61,7 +62,7 @@ public void promptOptionsTools() {

final String TOOL_FUNCTION_NAME = "CurrentWeather";

var client = new OpenAiChatModel(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build());

var request = client.createRequest(new Prompt("Test message content",
Expand Down Expand Up @@ -91,7 +92,7 @@ public void defaultOptionsTools() {

final String TOOL_FUNCTION_NAME = "CurrentWeather";

var client = new OpenAiChatModel(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi(new StaticApiKey("TEST")),
OpenAiChatOptions.builder()
.withModel("DEFAULT_MODEL")
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.openai;

import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.ai.openai.api.OpenAiAudioApi;
Expand All @@ -30,7 +31,7 @@ public class OpenAiTestConfiguration {

@Bean
public OpenAiApi openAiApi() {
return new OpenAiApi(getApiKey());
return new OpenAiApi(new StaticApiKey(getApiKey()));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
Expand All @@ -39,7 +40,7 @@
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiIT {

OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
OpenAiApi openAiApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));

@Test
void chatCompletionEntity() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.slf4j.LoggerFactory;

import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
Expand All @@ -51,7 +52,7 @@ public class OpenAiApiToolFunctionCallIT {

MockWeatherService weatherService = new MockWeatherService();

OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
OpenAiApi completionApi = new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));

private static <T> T fromJson(String json, Class<T> targetClass) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
Expand Down Expand Up @@ -67,7 +68,7 @@ static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi("Invalid API Key");
return new OpenAiApi(new StaticApiKey("Invalid API Key"));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -197,7 +198,7 @@ static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
Expand Down Expand Up @@ -169,7 +170,7 @@ public TestObservationRegistry observationRegistry() {

@Bean
public OpenAiApi openAiApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -355,7 +356,7 @@ static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
Expand Down Expand Up @@ -234,7 +235,7 @@ static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.security.ApiKey;
import org.springframework.ai.model.security.StaticApiKey;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders;
Expand Down Expand Up @@ -56,7 +58,7 @@
@RestClientTest(OpenAiChatModelWithChatResponseMetadataTests.Config.class)
public class OpenAiChatModelWithChatResponseMetadataTests {

private static String TEST_API_KEY = "sk-1234567890";
private static ApiKey TEST_API_KEY = new StaticApiKey("sk-1234567890");

@Autowired
private OpenAiChatModel openAiChatClient;
Expand Down Expand Up @@ -135,7 +137,7 @@ private void prepareMock() {

this.server.expect(requestTo("/v1/chat/completions"))
.andExpect(method(HttpMethod.POST))
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY.getValue()))
.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -55,16 +56,18 @@ static OpenAiChatOptions forModelName(String modelName) {
static Stream<ChatModel> openAiCompatibleApis() {
Stream.Builder<ChatModel> builder = Stream.builder();

builder.add(new OpenAiChatModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")), forModelName("gpt-3.5-turbo")));
builder.add(new OpenAiChatModel(new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY"))),
forModelName("gpt-3.5-turbo")));

if (System.getenv("GROQ_API_KEY") != null) {
builder.add(new OpenAiChatModel(new OpenAiApi("https://api.groq.com/openai", System.getenv("GROQ_API_KEY")),
builder.add(new OpenAiChatModel(
new OpenAiApi("https://api.groq.com/openai", new StaticApiKey(System.getenv("GROQ_API_KEY"))),
forModelName("llama3-8b-8192")));
}

if (System.getenv("OPEN_ROUTER_API_KEY") != null) {
builder.add(new OpenAiChatModel(
new OpenAiApi("https://openrouter.ai/api", System.getenv("OPEN_ROUTER_API_KEY")),
new OpenAiApi("https://openrouter.ai/api", new StaticApiKey(System.getenv("OPEN_ROUTER_API_KEY"))),
forModelName("meta-llama/llama-3-8b-instruct")));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
Expand Down Expand Up @@ -216,7 +217,7 @@ public ChatClient chatClient(OpenAiChatModel chatModel) {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return new OpenAiApi(new StaticApiKey(System.getenv("OPENAI_API_KEY")));
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.security.StaticApiKey;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
Expand Down Expand Up @@ -380,7 +381,7 @@ static class Config {

@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(GROQ_BASE_URL, System.getenv("GROQ_API_KEY"));
return new OpenAiApi(GROQ_BASE_URL, new StaticApiKey(System.getenv("GROQ_API_KEY")));
}

@Bean
Expand Down
Loading