Skip to content

Commit 7e2cefc

Browse files
committed
Share options instance between DefaultFunctionCallingOptionsBuilder and parent class
1 parent adefca1 commit 7e2cefc

File tree

3 files changed

+139
-10
lines changed

3 files changed

+139
-10
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public class DefaultChatOptionsBuilder<T extends DefaultChatOptionsBuilder<T>> implements ChatOptions.Builder<T> {
2525

26-
private final DefaultChatOptions options = new DefaultChatOptions();
26+
protected DefaultChatOptions options = new DefaultChatOptions();
2727

2828
protected T self() {
2929
return (T) this;

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,43 +36,49 @@ public class DefaultFunctionCallingOptionsBuilder
3636
extends DefaultChatOptionsBuilder<DefaultFunctionCallingOptionsBuilder>
3737
implements FunctionCallingOptions.Builder<DefaultFunctionCallingOptionsBuilder> {
3838

39-
private final DefaultFunctionCallingOptions functionCallingOptions = new DefaultFunctionCallingOptions();
39+
private final DefaultFunctionCallingOptions functionCallingOptions;
40+
41+
public DefaultFunctionCallingOptionsBuilder() {
42+
this.functionCallingOptions = new DefaultFunctionCallingOptions();
43+
// Set the options in the parent class to be the same instance
44+
super.options = this.functionCallingOptions;
45+
}
4046

4147
public DefaultFunctionCallingOptionsBuilder functionCallbacks(List<FunctionCallback> functionCallbacks) {
4248
this.functionCallingOptions.setFunctionCallbacks(functionCallbacks);
43-
return this;
49+
return self();
4450
}
4551

4652
public DefaultFunctionCallingOptionsBuilder functionCallbacks(FunctionCallback... functionCallbacks) {
4753
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
4854
this.functionCallingOptions.setFunctionCallbacks(List.of(functionCallbacks));
49-
return this;
55+
return self();
5056
}
5157

5258
public DefaultFunctionCallingOptionsBuilder functions(Set<String> functions) {
5359
this.functionCallingOptions.setFunctions(functions);
54-
return this;
60+
return self();
5561
}
5662

5763
public DefaultFunctionCallingOptionsBuilder function(String function) {
5864
Assert.notNull(function, "Function must not be null");
5965
var set = new HashSet<>(this.functionCallingOptions.getFunctions());
6066
set.add(function);
6167
this.functionCallingOptions.setFunctions(set);
62-
return this;
68+
return self();
6369
}
6470

6571
public DefaultFunctionCallingOptionsBuilder proxyToolCalls(Boolean proxyToolCalls) {
6672
this.functionCallingOptions.setProxyToolCalls(proxyToolCalls);
67-
return this;
73+
return self();
6874
}
6975

7076
public DefaultFunctionCallingOptionsBuilder toolContext(Map<String, Object> context) {
7177
Assert.notNull(context, "Tool context must not be null");
7278
Map<String, Object> newContext = new HashMap<>(this.functionCallingOptions.getToolContext());
7379
newContext.putAll(context);
7480
this.functionCallingOptions.setToolContext(newContext);
75-
return this;
81+
return self();
7682
}
7783

7884
public DefaultFunctionCallingOptionsBuilder toolContext(String key, Object value) {
@@ -81,7 +87,7 @@ public DefaultFunctionCallingOptionsBuilder toolContext(String key, Object value
8187
Map<String, Object> newContext = new HashMap<>(this.functionCallingOptions.getToolContext());
8288
newContext.put(key, value);
8389
this.functionCallingOptions.setToolContext(newContext);
84-
return this;
90+
return self();
8591
}
8692

8793
public FunctionCallingOptions build() {

spring-ai-core/src/test/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilderTests.java

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727

28+
import org.springframework.ai.chat.prompt.ChatOptions;
29+
2830
import static org.assertj.core.api.Assertions.assertThat;
2931
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3032

@@ -41,6 +43,109 @@ void setUp() {
4143
builder = new DefaultFunctionCallingOptionsBuilder();
4244
}
4345

46+
// Tests for inherited ChatOptions properties
47+
48+
@Test
49+
void shouldBuildWithModel() {
50+
// When
51+
ChatOptions options = builder.model("gpt-4").build();
52+
53+
// Then
54+
assertThat(options.getModel()).isEqualTo("gpt-4");
55+
}
56+
57+
@Test
58+
void shouldBuildWithFrequencyPenalty() {
59+
// When
60+
ChatOptions options = builder.frequencyPenalty(0.5).build();
61+
62+
// Then
63+
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
64+
}
65+
66+
@Test
67+
void shouldBuildWithMaxTokens() {
68+
// When
69+
ChatOptions options = builder.maxTokens(100).build();
70+
71+
// Then
72+
assertThat(options.getMaxTokens()).isEqualTo(100);
73+
}
74+
75+
@Test
76+
void shouldBuildWithPresencePenalty() {
77+
// When
78+
ChatOptions options = builder.presencePenalty(0.7).build();
79+
80+
// Then
81+
assertThat(options.getPresencePenalty()).isEqualTo(0.7);
82+
}
83+
84+
@Test
85+
void shouldBuildWithStopSequences() {
86+
// Given
87+
List<String> stopSequences = List.of("stop1", "stop2");
88+
89+
// When
90+
ChatOptions options = builder.stopSequences(stopSequences).build();
91+
92+
// Then
93+
assertThat(options.getStopSequences()).hasSize(2).containsExactlyElementsOf(stopSequences);
94+
}
95+
96+
@Test
97+
void shouldBuildWithTemperature() {
98+
// When
99+
ChatOptions options = builder.temperature(0.8).build();
100+
101+
// Then
102+
assertThat(options.getTemperature()).isEqualTo(0.8);
103+
}
104+
105+
@Test
106+
void shouldBuildWithTopK() {
107+
// When
108+
ChatOptions options = builder.topK(5).build();
109+
110+
// Then
111+
assertThat(options.getTopK()).isEqualTo(5);
112+
}
113+
114+
@Test
115+
void shouldBuildWithTopP() {
116+
// When
117+
ChatOptions options = builder.topP(0.9).build();
118+
119+
// Then
120+
assertThat(options.getTopP()).isEqualTo(0.9);
121+
}
122+
123+
@Test
124+
void shouldBuildWithAllInheritedOptions() {
125+
// When
126+
ChatOptions options = builder.model("gpt-4")
127+
.frequencyPenalty(0.5)
128+
.maxTokens(100)
129+
.presencePenalty(0.7)
130+
.stopSequences(List.of("stop1", "stop2"))
131+
.temperature(0.8)
132+
.topK(5)
133+
.topP(0.9)
134+
.build();
135+
136+
// Then
137+
assertThat(options.getModel()).isEqualTo("gpt-4");
138+
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
139+
assertThat(options.getMaxTokens()).isEqualTo(100);
140+
assertThat(options.getPresencePenalty()).isEqualTo(0.7);
141+
assertThat(options.getStopSequences()).containsExactly("stop1", "stop2");
142+
assertThat(options.getTemperature()).isEqualTo(0.8);
143+
assertThat(options.getTopK()).isEqualTo(5);
144+
assertThat(options.getTopP()).isEqualTo(0.9);
145+
}
146+
147+
// Original FunctionCallingOptions tests
148+
44149
@Test
45150
void shouldBuildWithFunctionCallbacksList() {
46151
// Given
@@ -195,7 +300,15 @@ void shouldBuildWithAllOptions() {
195300
Map<String, Object> context = Map.of("key1", "value1");
196301

197302
// When
198-
FunctionCallingOptions options = builder.functionCallbacks(callback)
303+
FunctionCallingOptions options = builder.model("gpt-4")
304+
.frequencyPenalty(0.5)
305+
.maxTokens(100)
306+
.presencePenalty(0.7)
307+
.stopSequences(List.of("stop1", "stop2"))
308+
.temperature(0.8)
309+
.topK(5)
310+
.topP(0.9)
311+
.functionCallbacks(callback)
199312
.functions(functions)
200313
.proxyToolCalls(true)
201314
.toolContext(context)
@@ -206,6 +319,16 @@ void shouldBuildWithAllOptions() {
206319
assertThat(options.getFunctions()).hasSize(1).containsExactlyElementsOf(functions);
207320
assertThat(options.getProxyToolCalls()).isTrue();
208321
assertThat(options.getToolContext()).hasSize(1).containsAllEntriesOf(context);
322+
323+
ChatOptions chatOptions = options;
324+
assertThat(chatOptions.getModel()).isEqualTo("gpt-4");
325+
assertThat(chatOptions.getFrequencyPenalty()).isEqualTo(0.5);
326+
assertThat(chatOptions.getMaxTokens()).isEqualTo(100);
327+
assertThat(chatOptions.getPresencePenalty()).isEqualTo(0.7);
328+
assertThat(chatOptions.getStopSequences()).containsExactly("stop1", "stop2");
329+
assertThat(chatOptions.getTemperature()).isEqualTo(0.8);
330+
assertThat(chatOptions.getTopK()).isEqualTo(5);
331+
assertThat(chatOptions.getTopP()).isEqualTo(0.9);
209332
}
210333

211334
}

0 commit comments

Comments
 (0)