1818
1919import org .springframework .ai .chat .prompt .ChatOptions ;
2020import org .springframework .ai .model .function .FunctionCallback ;
21- import org .springframework .ai .tool .ToolCallback ;
2221import org .springframework .lang .Nullable ;
2322import org .springframework .util .Assert ;
24- import org .springframework .util .CollectionUtils ;
25- import org .springframework .util .StringUtils ;
2623
2724import java .util .ArrayList ;
25+ import java .util .Arrays ;
2826import java .util .HashMap ;
2927import java .util .HashSet ;
3028import java .util .List ;
3937 */
4038public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
4139
42- private List <ToolCallback > toolCallbacks = new ArrayList <>();
40+ private List <FunctionCallback > toolCallbacks = new ArrayList <>();
4341
4442 private Set <String > tools = new HashSet <>();
4543
4644 private Map <String , Object > toolContext = new HashMap <>();
4745
4846 @ Nullable
49- private Boolean toolCallReturnDirect ;
47+ private Boolean toolExecutionEnabled ;
5048
5149 @ Nullable
5250 private String model ;
@@ -73,23 +71,17 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions {
7371 private Double topP ;
7472
7573 @ Override
76- public List <ToolCallback > getToolCallbacks () {
74+ public List <FunctionCallback > getToolCallbacks () {
7775 return List .copyOf (this .toolCallbacks );
7876 }
7977
8078 @ Override
81- public void setToolCallbacks (List <ToolCallback > toolCallbacks ) {
79+ public void setToolCallbacks (List <FunctionCallback > toolCallbacks ) {
8280 Assert .notNull (toolCallbacks , "toolCallbacks cannot be null" );
8381 Assert .noNullElements (toolCallbacks , "toolCallbacks cannot contain null elements" );
8482 this .toolCallbacks = new ArrayList <>(toolCallbacks );
8583 }
8684
87- @ Override
88- public void setToolCallbacks (ToolCallback ... toolCallbacks ) {
89- Assert .notNull (toolCallbacks , "toolCallbacks cannot be null" );
90- setToolCallbacks (List .of (toolCallbacks ));
91- }
92-
9385 @ Override
9486 public Set <String > getTools () {
9587 return Set .copyOf (this .tools );
@@ -103,12 +95,6 @@ public void setTools(Set<String> tools) {
10395 this .tools = new HashSet <>(tools );
10496 }
10597
106- @ Override
107- public void setTools (String ... tools ) {
108- Assert .notNull (tools , "tools cannot be null" );
109- setTools (Set .of (tools ));
110- }
111-
11298 @ Override
11399 public Map <String , Object > getToolContext () {
114100 return Map .copyOf (this .toolContext );
@@ -123,23 +109,23 @@ public void setToolContext(Map<String, Object> toolContext) {
123109
124110 @ Override
125111 @ Nullable
126- public Boolean getToolCallReturnDirect () {
127- return this .toolCallReturnDirect ;
112+ public Boolean isToolExecutionEnabled () {
113+ return this .toolExecutionEnabled ;
128114 }
129115
130116 @ Override
131- public void setToolCallReturnDirect (@ Nullable Boolean toolCallReturnDirect ) {
132- this .toolCallReturnDirect = toolCallReturnDirect ;
117+ public void setToolExecutionEnabled (@ Nullable Boolean toolExecutionEnabled ) {
118+ this .toolExecutionEnabled = toolExecutionEnabled ;
133119 }
134120
135121 @ Override
136122 public List <FunctionCallback > getFunctionCallbacks () {
137- return getToolCallbacks (). stream (). map ( FunctionCallback . class :: cast ). toList () ;
123+ return getToolCallbacks ();
138124 }
139125
140126 @ Override
141127 public void setFunctionCallbacks (List <FunctionCallback > functionCallbacks ) {
142- throw new UnsupportedOperationException ( "Not supported. Call setToolCallbacks instead." );
128+ setToolCallbacks ( functionCallbacks );
143129 }
144130
145131 @ Override
@@ -155,12 +141,12 @@ public void setFunctions(Set<String> functions) {
155141 @ Override
156142 @ Nullable
157143 public Boolean getProxyToolCalls () {
158- return getToolCallReturnDirect () ;
144+ return isToolExecutionEnabled () != null ? ! isToolExecutionEnabled () : null ;
159145 }
160146
161147 @ Override
162148 public void setProxyToolCalls (@ Nullable Boolean proxyToolCalls ) {
163- setToolCallReturnDirect (proxyToolCalls != null && proxyToolCalls );
149+ setToolExecutionEnabled (proxyToolCalls == null || ! proxyToolCalls );
164150 }
165151
166152 @ Override
@@ -250,7 +236,7 @@ public <T extends ChatOptions> T copy() {
250236 options .setToolCallbacks (getToolCallbacks ());
251237 options .setTools (getTools ());
252238 options .setToolContext (getToolContext ());
253- options .setToolCallReturnDirect ( getToolCallReturnDirect ());
239+ options .setToolExecutionEnabled ( isToolExecutionEnabled ());
254240 options .setModel (getModel ());
255241 options .setFrequencyPenalty (getFrequencyPenalty ());
256242 options .setMaxTokens (getMaxTokens ());
@@ -262,55 +248,6 @@ public <T extends ChatOptions> T copy() {
262248 return (T ) options ;
263249 }
264250
265- /**
266- * Merge the given {@link ChatOptions} into this instance.
267- */
268- public ToolCallingChatOptions merge (ChatOptions options ) {
269- ToolCallingChatOptions .Builder builder = ToolCallingChatOptions .builder ();
270- builder .model (StringUtils .hasText (options .getModel ()) ? options .getModel () : this .getModel ());
271- builder .frequencyPenalty (
272- options .getFrequencyPenalty () != null ? options .getFrequencyPenalty () : this .getFrequencyPenalty ());
273- builder .maxTokens (options .getMaxTokens () != null ? options .getMaxTokens () : this .getMaxTokens ());
274- builder .presencePenalty (
275- options .getPresencePenalty () != null ? options .getPresencePenalty () : this .getPresencePenalty ());
276- builder .stopSequences (options .getStopSequences () != null ? new ArrayList <>(options .getStopSequences ())
277- : this .getStopSequences ());
278- builder .temperature (options .getTemperature () != null ? options .getTemperature () : this .getTemperature ());
279- builder .topK (options .getTopK () != null ? options .getTopK () : this .getTopK ());
280- builder .topP (options .getTopP () != null ? options .getTopP () : this .getTopP ());
281-
282- if (options instanceof ToolCallingChatOptions toolOptions ) {
283- List <ToolCallback > toolCallbacks = new ArrayList <>(this .toolCallbacks );
284- if (!CollectionUtils .isEmpty (toolOptions .getToolCallbacks ())) {
285- toolCallbacks .addAll (toolOptions .getToolCallbacks ());
286- }
287- builder .toolCallbacks (toolCallbacks );
288-
289- Set <String > tools = new HashSet <>(this .tools );
290- if (!CollectionUtils .isEmpty (toolOptions .getTools ())) {
291- tools .addAll (toolOptions .getTools ());
292- }
293- builder .tools (tools );
294-
295- Map <String , Object > toolContext = new HashMap <>(this .toolContext );
296- if (!CollectionUtils .isEmpty (toolOptions .getToolContext ())) {
297- toolContext .putAll (toolOptions .getToolContext ());
298- }
299- builder .toolContext (toolContext );
300-
301- builder .toolCallReturnDirect (toolOptions .getToolCallReturnDirect () != null
302- ? toolOptions .getToolCallReturnDirect () : this .getToolCallReturnDirect ());
303- }
304- else {
305- builder .toolCallbacks (this .toolCallbacks );
306- builder .tools (this .tools );
307- builder .toolContext (this .toolContext );
308- builder .toolCallReturnDirect (this .toolCallReturnDirect );
309- }
310-
311- return builder .build ();
312- }
313-
314251 public static Builder builder () {
315252 return new Builder ();
316253 }
@@ -323,14 +260,15 @@ public static class Builder implements ToolCallingChatOptions.Builder {
323260 private final DefaultToolCallingChatOptions options = new DefaultToolCallingChatOptions ();
324261
325262 @ Override
326- public ToolCallingChatOptions .Builder toolCallbacks (List <ToolCallback > toolCallbacks ) {
263+ public ToolCallingChatOptions .Builder toolCallbacks (List <FunctionCallback > toolCallbacks ) {
327264 this .options .setToolCallbacks (toolCallbacks );
328265 return this ;
329266 }
330267
331268 @ Override
332- public ToolCallingChatOptions .Builder toolCallbacks (ToolCallback ... toolCallbacks ) {
333- this .options .setToolCallbacks (toolCallbacks );
269+ public ToolCallingChatOptions .Builder toolCallbacks (FunctionCallback ... toolCallbacks ) {
270+ Assert .notNull (toolCallbacks , "toolCallbacks cannot be null" );
271+ this .options .setToolCallbacks (Arrays .asList (toolCallbacks ));
334272 return this ;
335273 }
336274
@@ -342,7 +280,8 @@ public ToolCallingChatOptions.Builder tools(Set<String> toolNames) {
342280
343281 @ Override
344282 public ToolCallingChatOptions .Builder tools (String ... toolNames ) {
345- this .options .setTools (toolNames );
283+ Assert .notNull (toolNames , "toolNames cannot be null" );
284+ this .options .setTools (Set .of (toolNames ));
346285 return this ;
347286 }
348287
@@ -363,16 +302,15 @@ public ToolCallingChatOptions.Builder toolContext(String key, Object value) {
363302 }
364303
365304 @ Override
366- public ToolCallingChatOptions .Builder toolCallReturnDirect (@ Nullable Boolean toolCallReturnDirect ) {
367- this .options .setToolCallReturnDirect ( toolCallReturnDirect );
305+ public ToolCallingChatOptions .Builder toolExecutionEnabled (@ Nullable Boolean toolExecutionEnabled ) {
306+ this .options .setToolExecutionEnabled ( toolExecutionEnabled );
368307 return this ;
369308 }
370309
371310 @ Override
372311 @ Deprecated // Use toolCallbacks() instead
373312 public ToolCallingChatOptions .Builder functionCallbacks (List <FunctionCallback > functionCallbacks ) {
374- Assert .notNull (functionCallbacks , "functionCallbacks cannot be null" );
375- return toolCallbacks (functionCallbacks .stream ().map (ToolCallback .class ::cast ).toList ());
313+ return toolCallbacks (functionCallbacks );
376314 }
377315
378316 @ Override
@@ -395,9 +333,9 @@ public ToolCallingChatOptions.Builder function(String function) {
395333 }
396334
397335 @ Override
398- @ Deprecated // Use toolCallReturnDirect () instead
336+ @ Deprecated // Use toolExecutionEnabled () instead
399337 public ToolCallingChatOptions .Builder proxyToolCalls (@ Nullable Boolean proxyToolCalls ) {
400- return toolCallReturnDirect (proxyToolCalls != null && proxyToolCalls );
338+ return toolExecutionEnabled (proxyToolCalls == null || ! proxyToolCalls );
401339 }
402340
403341 @ Override
0 commit comments