Skip to content

Commit ac06ddc

Browse files
committed
feat: add mutate functionality for OpenAiApi and OpenAiChatModel builders
- Introduced mutate() methods to OpenAiApi and OpenAiChatModel, enabling creation of new builder instances from existing objects. - Allows safe modification and copying of configuration for APIs and models. - Refactored internal fields and getters to support mutation/copy patterns. - Updated integration tests to leverage mutate for dynamic client creation.
1 parent 97f90b1 commit ac06ddc

File tree

5 files changed

+307
-3
lines changed

5 files changed

+307
-3
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,30 @@ public static Builder builder() {
687687
return new Builder();
688688
}
689689

690+
/**
691+
* Returns a builder pre-populated with the current configuration for mutation.
692+
*/
693+
public Builder mutate() {
694+
return new Builder(this);
695+
}
696+
697+
@Override
698+
public OpenAiChatModel clone() {
699+
return this.mutate().build();
700+
}
701+
690702
public static final class Builder {
691703

704+
// Copy constructor for mutate()
705+
public Builder(OpenAiChatModel model) {
706+
this.openAiApi = model.openAiApi;
707+
this.defaultOptions = model.defaultOptions;
708+
this.toolCallingManager = model.toolCallingManager;
709+
this.toolExecutionEligibilityPredicate = model.toolExecutionEligibilityPredicate;
710+
this.retryTemplate = model.retryTemplate;
711+
this.observationRegistry = model.observationRegistry;
712+
}
713+
692714
private OpenAiApi openAiApi;
693715

694716
private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@
6565
*/
6666
public class OpenAiApi {
6767

68+
/**
69+
* Returns a builder pre-populated with the current configuration for mutation.
70+
*/
71+
public Builder mutate() {
72+
return new Builder(this);
73+
}
74+
6875
public static Builder builder() {
6976
return new Builder();
7077
}
@@ -75,10 +82,19 @@ public static Builder builder() {
7582

7683
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
7784

85+
// Store config fields for mutate/copy
86+
private final String baseUrl;
87+
88+
private final ApiKey apiKey;
89+
90+
private final MultiValueMap<String, String> headers;
91+
7892
private final String completionsPath;
7993

8094
private final String embeddingsPath;
8195

96+
private final ResponseErrorHandler responseErrorHandler;
97+
8298
private final RestClient restClient;
8399

84100
private final WebClient webClient;
@@ -99,13 +115,17 @@ public static Builder builder() {
99115
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
100116
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
101117
ResponseErrorHandler responseErrorHandler) {
118+
this.baseUrl = baseUrl;
119+
this.apiKey = apiKey;
120+
this.headers = headers;
121+
this.completionsPath = completionsPath;
122+
this.embeddingsPath = embeddingsPath;
123+
this.responseErrorHandler = responseErrorHandler;
102124

103125
Assert.hasText(completionsPath, "Completions Path must not be null");
104126
Assert.hasText(embeddingsPath, "Embeddings Path must not be null");
105127
Assert.notNull(headers, "Headers must not be null");
106128

107-
this.completionsPath = completionsPath;
108-
this.embeddingsPath = embeddingsPath;
109129
// @formatter:off
110130
Consumer<HttpHeaders> finalHeaders = h -> {
111131
if (!(apiKey instanceof NoopApiKey)) {
@@ -1674,6 +1694,21 @@ public record EmbeddingList<T>(// @formatter:off
16741694

16751695
public static class Builder {
16761696

1697+
public Builder() {
1698+
}
1699+
1700+
// Copy constructor for mutate()
1701+
public Builder(OpenAiApi api) {
1702+
this.baseUrl = api.getBaseUrl();
1703+
this.apiKey = api.getApiKey();
1704+
this.headers = new LinkedMultiValueMap<>(api.getHeaders());
1705+
this.completionsPath = api.getCompletionsPath();
1706+
this.embeddingsPath = api.getEmbeddingsPath();
1707+
this.restClientBuilder = api.restClient != null ? api.restClient.mutate() : RestClient.builder();
1708+
this.webClientBuilder = api.webClient != null ? api.webClient.mutate() : WebClient.builder();
1709+
this.responseErrorHandler = api.getResponseErrorHandler();
1710+
}
1711+
16771712
private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL;
16781713

16791714
private ApiKey apiKey;
@@ -1752,4 +1787,29 @@ public OpenAiApi build() {
17521787

17531788
}
17541789

1790+
// Package-private getters for mutate/copy
1791+
String getBaseUrl() {
1792+
return this.baseUrl;
1793+
}
1794+
1795+
ApiKey getApiKey() {
1796+
return this.apiKey;
1797+
}
1798+
1799+
MultiValueMap<String, String> getHeaders() {
1800+
return this.headers;
1801+
}
1802+
1803+
String getCompletionsPath() {
1804+
return this.completionsPath;
1805+
}
1806+
1807+
String getEmbeddingsPath() {
1808+
return this.embeddingsPath;
1809+
}
1810+
1811+
ResponseErrorHandler getResponseErrorHandler() {
1812+
return this.responseErrorHandler;
1813+
}
1814+
17551815
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Integration test for mutate/clone functionality on OpenAiApi and OpenAiChatModel.
3+
* This test demonstrates creating multiple ChatClient instances with different endpoints and options
4+
* from a single autoconfigured OpenAiChatModel/OpenAiApi, as per the feature request.
5+
*/
6+
package org.springframework.ai.openai.api;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
import org.springframework.ai.chat.client.ChatClient;
11+
import org.springframework.ai.openai.OpenAiChatModel;
12+
import org.springframework.ai.openai.OpenAiChatOptions;
13+
import org.springframework.util.LinkedMultiValueMap;
14+
15+
import static org.assertj.core.api.Assertions.assertThat;
16+
17+
class OpenAiChatModelMutateTests {
18+
19+
// Simulate autoconfigured base beans (in real usage, these would be @Autowired)
20+
private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build();
21+
22+
private final OpenAiChatModel baseModel = OpenAiChatModel.builder()
23+
.openAiApi(baseApi)
24+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build())
25+
.build();
26+
27+
@Test
28+
void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
29+
// Mutate for GPT-4
30+
OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build();
31+
OpenAiChatModel gpt4Model = baseModel.mutate()
32+
.openAiApi(gpt4Api)
33+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
34+
.build();
35+
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();
36+
37+
// Mutate for Llama
38+
OpenAiApi llamaApi = baseApi.mutate()
39+
.baseUrl("https://your-custom-endpoint.com")
40+
.apiKey("your-api-key-for-llama")
41+
.build();
42+
OpenAiChatModel llamaModel = baseModel.mutate()
43+
.openAiApi(llamaApi)
44+
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
45+
.build();
46+
ChatClient llamaClient = ChatClient.builder(llamaModel).build();
47+
48+
// Assert endpoints and models are different
49+
assertThat(gpt4Model).isNotSameAs(llamaModel);
50+
assertThat(gpt4Api).isNotSameAs(llamaApi);
51+
assertThat(gpt4Model.toString()).contains("gpt-4");
52+
assertThat(llamaModel.toString()).contains("llama-70b");
53+
// Optionally, assert endpoints
54+
// (In real usage, you might expose/get the baseUrl for assertion)
55+
}
56+
57+
@Test
58+
void testCloneCreatesDeepCopy() {
59+
OpenAiChatModel clone = baseModel.clone();
60+
assertThat(clone).isNotSameAs(baseModel);
61+
assertThat(clone.toString()).isEqualTo(baseModel.toString());
62+
}
63+
64+
@Test
65+
void mutateDoesNotAffectOriginal() {
66+
OpenAiChatModel mutated = baseModel.mutate()
67+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
68+
.build();
69+
assertThat(mutated).isNotSameAs(baseModel);
70+
assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4");
71+
assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
72+
}
73+
74+
@Test
75+
void mutateHeadersCreatesDistinctHeaders() {
76+
OpenAiApi mutatedApi = baseApi.mutate()
77+
.headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value"))))
78+
.build();
79+
80+
assertThat(mutatedApi.getHeaders()).containsKey("X-Test");
81+
assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test");
82+
}
83+
84+
@Test
85+
void mutateHandlesNullAndDefaults() {
86+
OpenAiApi apiWithDefaults = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("key").build();
87+
OpenAiApi mutated = apiWithDefaults.mutate().build();
88+
assertThat(mutated).isNotNull();
89+
assertThat(mutated.getBaseUrl()).isEqualTo("https://api.openai.com");
90+
assertThat(mutated.getApiKey().getValue()).isEqualTo("key");
91+
}
92+
93+
@Test
94+
void multipleSequentialMutationsProduceDistinctInstances() {
95+
OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build();
96+
OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build();
97+
OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build();
98+
assertThat(m1).isNotSameAs(m2);
99+
assertThat(m2).isNotSameAs(m3);
100+
assertThat(m1.getDefaultOptions().getModel()).isEqualTo("m1");
101+
assertThat(m2.getDefaultOptions().getModel()).isEqualTo("m2");
102+
assertThat(m3.getDefaultOptions().getModel()).isEqualTo("m3");
103+
}
104+
105+
@Test
106+
void mutateAndCloneAreEquivalent() {
107+
OpenAiChatModel mutated = baseModel.mutate().build();
108+
OpenAiChatModel cloned = baseModel.clone();
109+
assertThat(mutated.toString()).isEqualTo(cloned.toString());
110+
assertThat(mutated).isNotSameAs(cloned);
111+
}
112+
113+
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
@SpringBootTest(classes = GroqWithOpenAiChatModelIT.Config.class)
6868
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
69-
@Disabled("Due to rate limiting it is hard to run it in one go")
69+
// @Disabled("Due to rate limiting it is hard to run it in one go")
7070
class GroqWithOpenAiChatModelIT {
7171

7272
private static final Logger logger = LoggerFactory.getLogger(GroqWithOpenAiChatModelIT.class);
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.proxy;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
24+
import org.springframework.ai.chat.client.ChatClient;
25+
import org.springframework.ai.openai.OpenAiChatModel;
26+
import org.springframework.ai.openai.OpenAiChatOptions;
27+
import org.springframework.ai.openai.api.OpenAiApi;
28+
import org.springframework.beans.factory.annotation.Autowired;
29+
import org.springframework.boot.SpringBootConfiguration;
30+
import org.springframework.boot.test.context.SpringBootTest;
31+
import org.springframework.context.annotation.Bean;
32+
import org.springframework.test.context.ActiveProfiles;
33+
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
36+
@SpringBootTest(classes = MultiOpenAiClientIT.Config.class)
37+
@EnabledIfEnvironmentVariable(named = "GROQ_API_KEY", matches = ".+")
38+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
39+
@ActiveProfiles("logging-test")
40+
class MultiOpenAiClientIT {
41+
42+
private static final Logger logger = LoggerFactory.getLogger(MultiOpenAiClientIT.class);
43+
44+
@Autowired
45+
private OpenAiChatModel baseChatModel;
46+
47+
@Autowired
48+
private OpenAiApi baseOpenAiApi;
49+
50+
@Test
51+
void multiClientFlow() {
52+
// Derive a new OpenAiApi for Groq (Llama3)
53+
OpenAiApi groqApi = baseOpenAiApi.mutate()
54+
.baseUrl("https://api.groq.com/openai")
55+
.apiKey(System.getenv("GROQ_API_KEY"))
56+
.build();
57+
58+
// Derive a new OpenAiApi for OpenAI GPT-4
59+
OpenAiApi gpt4Api = baseOpenAiApi.mutate()
60+
.baseUrl("https://api.openai.com")
61+
.apiKey(System.getenv("OPENAI_API_KEY"))
62+
.build();
63+
64+
// Derive a new OpenAiChatModel for Groq
65+
OpenAiChatModel groqModel = baseChatModel.mutate()
66+
.openAiApi(groqApi)
67+
.defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build())
68+
.build();
69+
70+
// Derive a new OpenAiChatModel for GPT-4
71+
OpenAiChatModel gpt4Model = baseChatModel.mutate()
72+
.openAiApi(gpt4Api)
73+
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
74+
.build();
75+
76+
// Simple prompt for both models
77+
String prompt = "What is the capital of France?";
78+
79+
String groqResponse = ChatClient.builder(groqModel).build().prompt(prompt).call().content();
80+
String gpt4Response = ChatClient.builder(gpt4Model).build().prompt(prompt).call().content();
81+
82+
logger.info("Groq (Llama3) response: {}", groqResponse);
83+
logger.info("OpenAI GPT-4 response: {}", gpt4Response);
84+
85+
assertThat(groqResponse).containsIgnoringCase("Paris");
86+
assertThat(gpt4Response).containsIgnoringCase("Paris");
87+
88+
logger.info("OpenAI GPT-4 response: {}", gpt4Response);
89+
90+
assertThat(groqResponse).containsIgnoringCase("Paris");
91+
assertThat(gpt4Response).containsIgnoringCase("Paris");
92+
}
93+
94+
@SpringBootConfiguration
95+
static class Config {
96+
97+
@Bean
98+
public OpenAiApi chatCompletionApi() {
99+
return OpenAiApi.builder().baseUrl("foo").apiKey("bar").build();
100+
}
101+
102+
@Bean
103+
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
104+
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
105+
}
106+
107+
}
108+
109+
}

0 commit comments

Comments
 (0)