Skip to content

Commit b97a535

Browse files
committed
Implementation of the OpenAI Java SDK
Create the OpenAiOfficialChatModel class: - This is a first implementation that is not of good quality yet - Tests do not pass yet Signed-off-by: Julien Dubois <[email protected]>
1 parent 6ce657a commit b97a535

File tree

8 files changed

+1029
-62
lines changed

8 files changed

+1029
-62
lines changed

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/OpenAiOfficialChatModel.java

Lines changed: 784 additions & 0 deletions
Large diffs are not rendered by default.

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/OpenAiOfficialChatOptions.java

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@
3737
import java.util.Objects;
3838
import java.util.Set;
3939

40+
import static com.openai.models.ChatModel.GPT_5_MINI;
41+
4042
/**
4143
* Configuration information for the Chat Model implementation using the OpenAI Java SDK.
4244
*
4345
* @author Julien Dubois
4446
*/
4547
public class OpenAiOfficialChatOptions extends AbstractOpenAiOfficialOptions implements ToolCallingChatOptions {
4648

49+
public static final String DEFAULT_CHAT_MODEL = GPT_5_MINI.asString();
50+
4751
private static final Logger logger = LoggerFactory.getLogger(OpenAiOfficialChatOptions.class);
4852

4953
private Double frequencyPenalty;
@@ -60,8 +64,6 @@ public class OpenAiOfficialChatOptions extends AbstractOpenAiOfficialOptions imp
6064

6165
private Integer n;
6266

63-
private List<String> outputModalities;
64-
6567
private ChatCompletionAudioParam outputAudio;
6668

6769
private Double presencePenalty;
@@ -164,14 +166,6 @@ public void setN(Integer n) {
164166
this.n = n;
165167
}
166168

167-
public List<String> getOutputModalities() {
168-
return this.outputModalities;
169-
}
170-
171-
public void setOutputModalities(List<String> outputModalities) {
172-
this.outputModalities = outputModalities;
173-
}
174-
175169
public ChatCompletionAudioParam getOutputAudio() {
176170
return this.outputAudio;
177171
}
@@ -396,8 +390,7 @@ public boolean equals(Object o) {
396390
return Objects.equals(frequencyPenalty, options.frequencyPenalty)
397391
&& Objects.equals(logitBias, options.logitBias) && Objects.equals(logprobs, options.logprobs)
398392
&& Objects.equals(topLogprobs, options.topLogprobs) && Objects.equals(maxTokens, options.maxTokens)
399-
&& Objects.equals(n, options.n) && Objects.equals(outputModalities, options.outputModalities)
400-
&& Objects.equals(outputAudio, options.outputAudio)
393+
&& Objects.equals(n, options.n) && Objects.equals(outputAudio, options.outputAudio)
401394
&& Objects.equals(presencePenalty, options.presencePenalty)
402395
&& Objects.equals(responseFormat, options.responseFormat)
403396
&& Objects.equals(streamOptions, options.streamOptions) && Objects.equals(seed, options.seed)
@@ -415,24 +408,24 @@ public boolean equals(Object o) {
415408

416409
@Override
417410
public int hashCode() {
418-
return Objects.hash(frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, n, outputModalities,
419-
outputAudio, presencePenalty, responseFormat, streamOptions, seed, stop, temperature, topP, tools,
420-
toolChoice, user, parallelToolCalls, store, metadata, reasoningEffort, verbosity, serviceTier,
421-
toolCallbacks, toolNames, internalToolExecutionEnabled, httpHeaders, toolContext);
411+
return Objects.hash(frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, n, outputAudio,
412+
presencePenalty, responseFormat, streamOptions, seed, stop, temperature, topP, tools, toolChoice, user,
413+
parallelToolCalls, store, metadata, reasoningEffort, verbosity, serviceTier, toolCallbacks, toolNames,
414+
internalToolExecutionEnabled, httpHeaders, toolContext);
422415
}
423416

424417
@Override
425418
public String toString() {
426419
return "OpenAiOfficialChatOptions{" + "frequencyPenalty=" + frequencyPenalty + ", logitBias=" + logitBias
427420
+ ", logprobs=" + logprobs + ", topLogprobs=" + topLogprobs + ", maxTokens=" + maxTokens + ", n=" + n
428-
+ ", outputModalities=" + outputModalities + ", outputAudio=" + outputAudio + ", presencePenalty="
429-
+ presencePenalty + ", responseFormat=" + responseFormat + ", streamOptions=" + streamOptions
430-
+ ", seed=" + seed + ", stop=" + stop + ", temperature=" + temperature + ", topP=" + topP + ", tools="
431-
+ tools + ", toolChoice=" + toolChoice + ", user='" + user + '\'' + ", parallelToolCalls="
432-
+ parallelToolCalls + ", store=" + store + ", metadata=" + metadata + ", reasoningEffort='"
433-
+ reasoningEffort + '\'' + ", verbosity='" + verbosity + '\'' + ", serviceTier='" + serviceTier + '\''
434-
+ ", toolCallbacks=" + toolCallbacks + ", toolNames=" + toolNames + ", internalToolExecutionEnabled="
435-
+ internalToolExecutionEnabled + ", httpHeaders=" + httpHeaders + ", toolContext=" + toolContext + '}';
421+
+ ", outputAudio=" + outputAudio + ", presencePenalty=" + presencePenalty + ", responseFormat="
422+
+ responseFormat + ", streamOptions=" + streamOptions + ", seed=" + seed + ", stop=" + stop
423+
+ ", temperature=" + temperature + ", topP=" + topP + ", tools=" + tools + ", toolChoice=" + toolChoice
424+
+ ", user='" + user + '\'' + ", parallelToolCalls=" + parallelToolCalls + ", store=" + store
425+
+ ", metadata=" + metadata + ", reasoningEffort='" + reasoningEffort + '\'' + ", verbosity='"
426+
+ verbosity + '\'' + ", serviceTier='" + serviceTier + '\'' + ", toolCallbacks=" + toolCallbacks
427+
+ ", toolNames=" + toolNames + ", internalToolExecutionEnabled=" + internalToolExecutionEnabled
428+
+ ", httpHeaders=" + httpHeaders + ", toolContext=" + toolContext + '}';
436429
}
437430

438431
public static final class Builder {
@@ -449,8 +442,6 @@ public Builder from(OpenAiOfficialChatOptions fromOptions) {
449442
this.options.setMaxTokens(fromOptions.getMaxTokens());
450443
this.options.setMaxCompletionTokens(fromOptions.getMaxCompletionTokens());
451444
this.options.setN(fromOptions.getN());
452-
this.options.setOutputModalities(fromOptions.getOutputModalities() != null
453-
? new ArrayList<>(fromOptions.getOutputModalities()) : null);
454445
this.options.setOutputAudio(fromOptions.getOutputAudio());
455446
this.options.setPresencePenalty(fromOptions.getPresencePenalty());
456447
this.options.setResponseFormat(fromOptions.getResponseFormat());
@@ -505,9 +496,6 @@ public Builder merge(OpenAiOfficialChatOptions from) {
505496
if (from.getN() != null) {
506497
this.options.setN(from.getN());
507498
}
508-
if (from.getOutputModalities() != null) {
509-
this.options.setOutputModalities(new ArrayList<>(from.getOutputModalities()));
510-
}
511499
if (from.getOutputAudio() != null) {
512500
this.options.setOutputAudio(from.getOutputAudio());
513501
}
@@ -636,11 +624,6 @@ public Builder N(Integer n) {
636624
return this;
637625
}
638626

639-
public Builder outputModalities(List<String> modalities) {
640-
this.options.setOutputModalities(modalities);
641-
return this;
642-
}
643-
644627
public Builder outputAudio(ChatCompletionAudioParam audio) {
645628
this.options.setOutputAudio(audio);
646629
return this;

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/setup/OpenAiOfficialSetup.java

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
import com.openai.azure.AzureOpenAIServiceVersion;
2020
import com.openai.azure.credential.AzureApiKeyCredential;
2121
import com.openai.client.OpenAIClient;
22+
import com.openai.client.OpenAIClientAsync;
2223
import com.openai.client.okhttp.OpenAIOkHttpClient;
24+
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
2325
import com.openai.credential.Credential;
2426
import org.slf4j.Logger;
2527
import org.slf4j.LoggerFactory;
@@ -64,32 +66,64 @@ public static OpenAIClient setupSyncClient(String baseUrl, String apiKey, Creden
6466
boolean isAzure, boolean isGitHubModels, String modelName, Duration timeout, Integer maxRetries,
6567
Proxy proxy, Map<String, String> customHeaders) {
6668

67-
if (apiKey == null && credential == null) {
68-
var openAiKey = System.getenv("OPENAI_API_KEY");
69-
if (openAiKey != null) {
70-
apiKey = openAiKey;
71-
logger.debug("OpenAI API Key detected from environment variable OPENAI_API_KEY.");
72-
}
73-
var azureOpenAiKey = System.getenv("AZURE_OPENAI_KEY");
74-
if (azureOpenAiKey != null) {
75-
apiKey = azureOpenAiKey;
76-
logger.debug("Azure OpenAI Key detected from environment variable AZURE_OPENAI_KEY.");
77-
}
69+
baseUrl = detectBaseUrlFromEnv(baseUrl);
70+
var modelHost = detectModelHost(isAzure, isGitHubModels, baseUrl, azureDeploymentName,
71+
azureOpenAiServiceVersion);
72+
if (timeout == null) {
73+
timeout = DEFAULT_DURATION;
7874
}
79-
if (baseUrl == null) {
80-
var openAiBaseUrl = System.getenv("OPENAI_BASE_URL");
81-
if (openAiBaseUrl != null) {
82-
baseUrl = openAiBaseUrl;
83-
logger.debug("OpenAI Base URL detected from environment variable OPENAI_BASE_URL.");
84-
}
85-
var azureOpenAiBaseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
86-
if (azureOpenAiBaseUrl != null) {
87-
baseUrl = azureOpenAiBaseUrl;
88-
logger.debug("Azure OpenAI Base URL detected from environment variable AZURE_OPENAI_BASE_URL.");
89-
}
75+
if (maxRetries == null) {
76+
maxRetries = DEFAULT_MAX_RETRIES;
77+
}
78+
79+
OpenAIOkHttpClient.Builder builder = OpenAIOkHttpClient.builder();
80+
builder
81+
.baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion));
82+
83+
Credential calculatedCredential = calculateCredential(modelHost, apiKey, credential);
84+
String calculatedApiKey = calculateApiKey(modelHost, apiKey);
85+
if (calculatedCredential == null && calculatedApiKey == null) {
86+
throw new IllegalArgumentException("Either apiKey or credential must be set to authenticate");
87+
}
88+
else if (calculatedCredential != null) {
89+
builder.credential(calculatedCredential);
90+
}
91+
else {
92+
builder.apiKey(calculatedApiKey);
93+
}
94+
builder.organization(organizationId);
95+
96+
if (azureOpenAiServiceVersion != null) {
97+
builder.azureServiceVersion(azureOpenAiServiceVersion);
9098
}
9199

92-
ModelHost modelHost = detectModelHost(isAzure, isGitHubModels, baseUrl, azureDeploymentName,
100+
if (proxy != null) {
101+
builder.proxy(proxy);
102+
}
103+
104+
builder.putHeader("User-Agent", DEFAULT_USER_AGENT);
105+
if (customHeaders != null) {
106+
builder.putAllHeaders(customHeaders.entrySet()
107+
.stream()
108+
.collect(Collectors.toMap(Map.Entry::getKey, entry -> Collections.singletonList(entry.getValue()))));
109+
}
110+
111+
builder.timeout(timeout);
112+
builder.maxRetries(maxRetries);
113+
return builder.build();
114+
}
115+
116+
/**
117+
* The asynchronous client setup is the same as the synchronous one in the OpenAI Java
118+
* SDK, but uses a different client implementation.
119+
*/
120+
public static OpenAIClientAsync setupAsyncClient(String baseUrl, String apiKey, Credential credential,
121+
String azureDeploymentName, AzureOpenAIServiceVersion azureOpenAiServiceVersion, String organizationId,
122+
boolean isAzure, boolean isGitHubModels, String modelName, Duration timeout, Integer maxRetries,
123+
Proxy proxy, Map<String, String> customHeaders) {
124+
125+
baseUrl = detectBaseUrlFromEnv(baseUrl);
126+
var modelHost = detectModelHost(isAzure, isGitHubModels, baseUrl, azureDeploymentName,
93127
azureOpenAiServiceVersion);
94128
if (timeout == null) {
95129
timeout = DEFAULT_DURATION;
@@ -98,7 +132,7 @@ public static OpenAIClient setupSyncClient(String baseUrl, String apiKey, Creden
98132
maxRetries = DEFAULT_MAX_RETRIES;
99133
}
100134

101-
OpenAIOkHttpClient.Builder builder = OpenAIOkHttpClient.builder();
135+
OpenAIOkHttpClientAsync.Builder builder = OpenAIOkHttpClientAsync.builder();
102136
builder
103137
.baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion));
104138

@@ -135,8 +169,25 @@ else if (calculatedCredential != null) {
135169
return builder.build();
136170
}
137171

172+
static String detectBaseUrlFromEnv(String baseUrl) {
173+
if (baseUrl == null) {
174+
var openAiBaseUrl = System.getenv("OPENAI_BASE_URL");
175+
if (openAiBaseUrl != null) {
176+
baseUrl = openAiBaseUrl;
177+
logger.debug("OpenAI Base URL detected from environment variable OPENAI_BASE_URL.");
178+
}
179+
var azureOpenAiBaseUrl = System.getenv("AZURE_OPENAI_BASE_URL");
180+
if (azureOpenAiBaseUrl != null) {
181+
baseUrl = azureOpenAiBaseUrl;
182+
logger.debug("Azure OpenAI Base URL detected from environment variable AZURE_OPENAI_BASE_URL.");
183+
}
184+
}
185+
return baseUrl;
186+
}
187+
138188
static ModelHost detectModelHost(boolean isAzure, boolean isGitHubModels, String baseUrl,
139189
String azureDeploymentName, AzureOpenAIServiceVersion azureOpenAIServiceVersion) {
190+
140191
if (isAzure) {
141192
return ModelHost.AZURE_OPENAI; // Forced by the user
142193
}
@@ -159,8 +210,9 @@ else if (baseUrl.startsWith(GITHUB_MODELS_URL)) {
159210
return ModelHost.OPENAI;
160211
}
161212

162-
static String calculateBaseUrl(final String baseUrl, ModelHost modelHost, String modelName,
163-
String azureDeploymentName, AzureOpenAIServiceVersion azureOpenAiServiceVersion) {
213+
static String calculateBaseUrl(String baseUrl, ModelHost modelHost, String modelName, String azureDeploymentName,
214+
AzureOpenAIServiceVersion azureOpenAiServiceVersion) {
215+
164216
if (modelHost == ModelHost.OPENAI) {
165217
if (baseUrl == null || baseUrl.isBlank()) {
166218
return OPENAI_URL;
@@ -211,6 +263,18 @@ else if (modelHost == ModelHost.AZURE_OPENAI) {
211263
}
212264

213265
static String calculateApiKey(ModelHost modelHost, String apiKey) {
266+
if (apiKey == null) {
267+
var openAiKey = System.getenv("OPENAI_API_KEY");
268+
if (openAiKey != null) {
269+
apiKey = openAiKey;
270+
logger.debug("OpenAI API Key detected from environment variable OPENAI_API_KEY.");
271+
}
272+
var azureOpenAiKey = System.getenv("AZURE_OPENAI_KEY");
273+
if (azureOpenAiKey != null) {
274+
apiKey = azureOpenAiKey;
275+
logger.debug("Azure OpenAI Key detected from environment variable AZURE_OPENAI_KEY.");
276+
}
277+
}
214278
if (modelHost != ModelHost.AZURE_OPENAI && apiKey != null) {
215279
return apiKey;
216280
}

models/spring-ai-openai-official/src/test/java/org/springframework/ai/openaiofficial/OpenAiOfficialTestConfiguration.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,9 @@ public OpenAiOfficialImageModel openAiImageModel() {
3737
return new OpenAiOfficialImageModel();
3838
}
3939

40+
@Bean
41+
public OpenAiOfficialChatModel openAiChatModel() {
42+
return new OpenAiOfficialChatModel();
43+
}
44+
4045
}

0 commit comments

Comments
 (0)