Skip to content

Commit 9774b4e

Browse files
markpollacksobychacko
authored andcommitted
feat: add mutate functionality for OpenAiApi and OpenAiChatModel builders
- Introduced mutate() methods to OpenAiApi and OpenAiChatModel, enabling creation of new builder instances from existing objects. - Allows safe modification and copying of configuration for APIs and models. - Refactored internal fields and getters to support mutation/copy patterns. - Updated integration tests to leverage mutate for dynamic client creation.
1 parent bf16617 commit 9774b4e

File tree

5 files changed

+323
-3
lines changed

5 files changed

+323
-3
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,30 @@ public static Builder builder() {
694694
return new Builder();
695695
}
696696

697+
/**
698+
* Returns a builder pre-populated with the current configuration for mutation.
699+
*/
700+
public Builder mutate() {
701+
return new Builder(this);
702+
}
703+
704+
@Override
705+
public OpenAiChatModel clone() {
706+
return this.mutate().build();
707+
}
708+
697709
public static final class Builder {
698710

711+
// Copy constructor for mutate()
712+
public Builder(OpenAiChatModel model) {
713+
this.openAiApi = model.openAiApi;
714+
this.defaultOptions = model.defaultOptions;
715+
this.toolCallingManager = model.toolCallingManager;
716+
this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate;
717+
this.retryTemplate = model.retryTemplate;
718+
this.observationRegistry = model.observationRegistry;
719+
}
720+
699721
private OpenAiApi openAiApi;
700722

701723
private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@
6565
*/
6666
public class OpenAiApi {
6767

68+
/**
69+
* Returns a builder pre-populated with the current configuration for mutation.
70+
*/
71+
public Builder mutate() {
72+
return new Builder(this);
73+
}
74+
6875
public static Builder builder() {
6976
return new Builder();
7077
}
@@ -75,10 +82,19 @@ public static Builder builder() {
7582

7683
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
7784

85+
// Store config fields for mutate/copy
86+
private final String baseUrl;
87+
88+
private final ApiKey apiKey;
89+
90+
private final MultiValueMap<String, String> headers;
91+
7892
private final String completionsPath;
7993

8094
private final String embeddingsPath;
8195

96+
private final ResponseErrorHandler responseErrorHandler;
97+
8298
private final RestClient restClient;
8399

84100
private final WebClient webClient;
@@ -99,13 +115,17 @@ public static Builder builder() {
99115
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
100116
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
101117
ResponseErrorHandler responseErrorHandler) {
118+
this.baseUrl = baseUrl;
119+
this.apiKey = apiKey;
120+
this.headers = headers;
121+
this.completionsPath = completionsPath;
122+
this.embeddingsPath = embeddingsPath;
123+
this.responseErrorHandler = responseErrorHandler;
102124

103125
Assert.hasText(completionsPath, "Completions Path must not be null");
104126
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
105127
Assert.notNull(headers, "Headers must not be null");
106128

107-
this.completionsPath = completionsPath;
108-
this.embeddingsPath = embeddingsPath;
109129
// @formatter:off
110130
Consumer<HttpHeaders> finalHeaders = h -> {
111131
if (!(apiKey instanceof NoopApiKey)) {
@@ -1773,6 +1793,21 @@ public record EmbeddingList<T>(// @formatter:off
17731793

17741794
public static class Builder {
17751795

1796+
public Builder() {
1797+
}
1798+
1799+
// Copy constructor for mutate()
1800+
public Builder(OpenAiApi api) {
1801+
this.baseUrl = api.getBaseUrl();
1802+
this.apiKey = api.getApiKey();
1803+
this.headers = new LinkedMultiValueMap<>(api.getHeaders());
1804+
this.completionsPath = api.getCompletionsPath();
1805+
this.embeddingsPath = api.getEmbeddingsPath();
1806+
this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder();
1807+
this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder();
1808+
this.responseErrorHandler = api.getResponseErrorHandler();
1809+
}
1810+
17761811
private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;
17771812

17781813
private ApiKey apiKey;
@@ -1851,4 +1886,29 @@ public OpenAiApi build() {
18511886

18521887
}
18531888

1889+
// Package-private getters for mutate/copy
1890+
String getBaseUrl() {
1891+
return this.baseUrl;
1892+
}
1893+
1894+
ApiKey getApiKey() {
1895+
return this.apiKey;
1896+
}
1897+
1898+
MultiValueMap<String, String> getHeaders() {
1899+
return this.headers;
1900+
}
1901+
1902+
String getCompletionsPath() {
1903+
return this.completionsPath;
1904+
}
1905+
1906+
String getEmbeddingsPath() {
1907+
return this.embeddingsPath;
1908+
}
1909+
1910+
ResponseErrorHandler getResponseErrorHandler() {
1911+
return this.responseErrorHandler;
1912+
}
1913+
18541914
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.api;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import org.springframework.ai.chat.client.ChatClient;
22+
import org.springframework.ai.openai.OpenAiChatModel;
23+
import org.springframework.ai.openai.OpenAiChatOptions;
24+
import org.springframework.util.LinkedMultiValueMap;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/*
29+
* Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel.
30+
* This test demonstrates creating multiple ChatClient instances with different endpoints and options
31+
* from a single autoconfigured OpenAiChatModel/OpenAiApi, as per the feature request.
32+
*/
33+
class OpenAiChatModelMutateTests {
34+
35+
// Simulate autoconfigured base beans (in real usage, these would be @Autowired)
36+
private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build();
37+
38+
private final OpenAiChatModel baseModel = OpenAiChatModel.builder()
39+
.openAiApi(baseApi)
40+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build())
41+
.build();
42+
43+
@Test
44+
void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
45+
// Mutate for GPT-4
46+
OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build();
47+
OpenAiChatModel gpt4Model = baseModel.mutate()
48+
.openAiApi(gpt4Api)
49+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
50+
.build();
51+
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();
52+
53+
// Mutate for Llama
54+
OpenAiApi llamaApi = baseApi.mutate()
55+
.baseUrl("https://your-custom-endpoint.com")
56+
.apiKey("your-api-key-for-llama")
57+
.build();
58+
OpenAiChatModel llamaModel = baseModel.mutate()
59+
.openAiApi(llamaApi)
60+
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
61+
.build();
62+
ChatClient llamaClient = ChatClient.builder(llamaModel).build();
63+
64+
// Assert endpoints and models are different
65+
assertThat(gpt4Model).isNotSameAs(llamaModel);
66+
assertThat(gpt4Api).isNotSameAs(llamaApi);
67+
assertThat(gpt4Model.toString()).contains("gpt-4");
68+
assertThat(llamaModel.toString()).contains("llama-70b");
69+
// Optionally, assert endpoints
70+
// (In real usage, you might expose/get the baseUrl for assertion)
71+
}
72+
73+
@Test
74+
void testCloneCreatesDeepCopy() {
75+
OpenAiChatModel clone = baseModel.clone();
76+
assertThat(clone).isNotSameAs(baseModel);
77+
assertThat(clone.toString()).isEqualTo(baseModel.toString());
78+
}
79+
80+
@Test
81+
void mutateDoesNotAffectOriginal() {
82+
OpenAiChatModel mutated = baseModel.mutate()
83+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
84+
.build();
85+
assertThat(mutated).isNotSameAs(baseModel);
86+
assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4");
87+
assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
88+
}
89+
90+
@Test
91+
void mutateHeadersCreatesDistinctHeaders() {
92+
OpenAiApi mutatedApi = baseApi.mutate()
93+
.headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value"))))
94+
.build();
95+
96+
assertThat(mutatedApi.getHeaders()).containsKey("X-Test");
97+
assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test");
98+
}
99+
100+
@Test
101+
void mutateHandlesNullAndDefaults() {
102+
OpenAiApi apiWithDefaults = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("key").build();
103+
OpenAiApi mutated = apiWithDefaults.mutate().build();
104+
assertThat(mutated).isNotNull();
105+
assertThat(mutated.getBaseUrl()).isEqualTo("https://api.openai.com");
106+
assertThat(mutated.getApiKey().getValue()).isEqualTo("key");
107+
}
108+
109+
@Test
110+
void multipleSequentialMutationsProduceDistinctInstances() {
111+
OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build();
112+
OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build();
113+
OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build();
114+
assertThat(m1).isNotSameAs(m2);
115+
assertThat(m2).isNotSameAs(m3);
116+
assertThat(m1.getDefaultOptions().getModel()).isEqualTo("m1");
117+
assertThat(m2.getDefaultOptions().getModel()).isEqualTo("m2");
118+
assertThat(m3.getDefaultOptions().getModel()).isEqualTo("m3");
119+
}
120+
121+
@Test
122+
void mutateAndCloneAreEquivalent() {
123+
OpenAiChatModel mutated = baseModel.mutate().build();
124+
OpenAiChatModel cloned = baseModel.clone();
125+
assertThat(mutated.toString()).isEqualTo(cloned.toString());
126+
assertThat(mutated).isNotSameAs(cloned);
127+
}
128+
129+
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
@SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class)
6868
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
69-
@Disabled("Due to rate limiting it is hard to run it in one go")
69+
// @Disabled("Due to rate limiting it is hard to run it in one go")
7070
class GroqWithOpenAiChatModelIT {
7171

7272
private static final Logger logger = LoggerFactory.getLogger(GroqWithOpenAiChatModelIT.class);
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.proxy;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
24+
import org.springframework.ai.chat.client.ChatClient;
25+
import org.springframework.ai.openai.OpenAiChatModel;
26+
import org.springframework.ai.openai.OpenAiChatOptions;
27+
import org.springframework.ai.openai.api.OpenAiApi;
28+
import org.springframework.beans.factory.annotation.Autowired;
29+
import org.springframework.boot.SpringBootConfiguration;
30+
import org.springframework.boot.test.context.SpringBootTest;
31+
import org.springframework.context.annotation.Bean;
32+
import org.springframework.test.context.ActiveProfiles;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
36+
@SpringBootTest(classes = MultiOpenAiClientIT.Config.class)
37+
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
38+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
39+
@ActiveProfiles("logging-test")
40+
class MultiOpenAiClientIT {
41+
42+
private static final Logger logger = LoggerFactory.getLogger(MultiOpenAiClientIT.class);
43+
44+
@Autowired
45+
private OpenAiChatModel baseChatModel;
46+
47+
@Autowired
48+
private OpenAiApi baseOpenAiApi;
49+
50+
@Test
51+
void multiClientFlow() {
52+
// Derive a new OpenAiApi for Groq (Llama3)
53+
OpenAiApi groqApi = baseOpenAiApi.mutate()
54+
.baseUrl("https://api.groq.com/openai")
55+
.apiKey(System.getenv("GROQ_API_KEY"))
56+
.build();
57+
58+
// Derive a new OpenAiApi for OpenAI GPT-4
59+
OpenAiApi gpt4Api = baseOpenAiApi.mutate()
60+
.baseUrl("https://api.openai.com")
61+
.apiKey(System.getenv("OPENAI_API_KEY"))
62+
.build();
63+
64+
// Derive a new OpenAiChatModel for Groq
65+
OpenAiChatModel groqModel = baseChatModel.mutate()
66+
.openAiApi(groqApi)
67+
.defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build())
68+
.build();
69+
70+
// Derive a new OpenAiChatModel for GPT-4
71+
OpenAiChatModel gpt4Model = baseChatModel.mutate()
72+
.openAiApi(gpt4Api)
73+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
74+
.build();
75+
76+
// Simple prompt for both models
77+
String prompt = "What is the capital of France?";
78+
79+
String groqResponse = ChatClient.builder(groqModel).build().prompt(prompt).call().content();
80+
String gpt4Response = ChatClient.builder(gpt4Model).build().prompt(prompt).call().content();
81+
82+
logger.info("Groq (Llama3) response: {}", groqResponse);
83+
logger.info("OpenAI GPT-4 response: {}", gpt4Response);
84+
85+
assertThat(groqResponse).containsIgnoringCase("Paris");
86+
assertThat(gpt4Response).containsIgnoringCase("Paris");
87+
88+
logger.info("OpenAI GPT-4 response: {}", gpt4Response);
89+
90+
assertThat(groqResponse).containsIgnoringCase("Paris");
91+
assertThat(gpt4Response).containsIgnoringCase("Paris");
92+
}
93+
94+
@SpringBootConfiguration
95+
static class Config {
96+
97+
@Bean
98+
public OpenAiApi chatCompletionApi() {
99+
return OpenAiApi.builder().baseUrl("foo").apiKey("bar").build();
100+
}
101+
102+
@Bean
103+
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
104+
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
105+
}
106+
107+
}
108+
109+
}

0 commit comments

Comments
 (0)