1
1
/*
2
- * Copyright 2023-2024 the original author or authors.
2
+ * Copyright 2023-2025 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
17
17
package org .springframework .ai .ollama .api ;
18
18
19
19
import java .util .ArrayList ;
20
+ import java .util .Arrays ;
20
21
import java .util .HashSet ;
21
22
import java .util .List ;
22
23
import java .util .Map ;
32
33
import org .springframework .ai .embedding .EmbeddingOptions ;
33
34
import org .springframework .ai .model .ModelOptionsUtils ;
34
35
import org .springframework .ai .model .function .FunctionCallback ;
35
- import org .springframework .ai .model .function .FunctionCallingOptions ;
36
+ import org .springframework .ai .model .tool .ToolCallingChatOptions ;
37
+ import org .springframework .lang .Nullable ;
36
38
import org .springframework .util .Assert ;
37
39
38
40
/**
48
50
* @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama Types</a>
49
51
*/
50
52
@ JsonInclude (Include .NON_NULL )
51
- public class OllamaOptions implements FunctionCallingOptions , EmbeddingOptions {
53
+ public class OllamaOptions implements ToolCallingChatOptions , EmbeddingOptions {
52
54
53
55
private static final List <String > NON_SUPPORTED_FIELDS = List .of ("model" , "format" , "keep_alive" , "truncate" );
54
56
@@ -305,28 +307,28 @@ public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions {
305
307
@ JsonProperty ("truncate" )
306
308
private Boolean truncate ;
307
309
310
+ @ JsonIgnore
311
+ private Boolean internalToolExecutionEnabled ;
312
+
308
313
/**
309
314
* Tool Function Callbacks to register with the ChatModel.
310
315
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
311
316
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
312
317
* from the registry to be used by the ChatModel chat completion requests.
313
318
*/
314
319
@ JsonIgnore
315
- private List <FunctionCallback > functionCallbacks = new ArrayList <>();
320
+ private List <FunctionCallback > toolCallbacks = new ArrayList <>();
316
321
317
322
/**
318
323
* List of functions, identified by their names, to configure for function calling in
319
324
* the chat completion requests.
320
325
* Functions with those names must exist in the functionCallbacks registry.
321
- * The {@link #functionCallbacks } from the PromptOptions are automatically enabled for the duration of the prompt execution.
326
+ * The {@link #toolCallbacks } from the PromptOptions are automatically enabled for the duration of the prompt execution.
322
327
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
323
328
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
324
329
*/
325
330
@ JsonIgnore
326
- private Set <String > functions = new HashSet <>();
327
-
328
- @ JsonIgnore
329
- private Boolean proxyToolCalls ;
331
+ private Set <String > toolNames = new HashSet <>();
330
332
331
333
@ JsonIgnore
332
334
private Map <String , Object > toolContext ;
@@ -381,9 +383,9 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
381
383
.mirostatEta (fromOptions .getMirostatEta ())
382
384
.penalizeNewline (fromOptions .getPenalizeNewline ())
383
385
.stop (fromOptions .getStop ())
384
- .functions (fromOptions .getFunctions ())
385
- .proxyToolCalls (fromOptions .getProxyToolCalls ())
386
- .functionCallbacks (fromOptions .getFunctionCallbacks ())
386
+ .tools (fromOptions .getTools ())
387
+ .internalToolExecutionEnabled (fromOptions .isInternalToolExecutionEnabled ())
388
+ .toolCallbacks (fromOptions .getToolCallbacks ())
387
389
.toolContext (fromOptions .getToolContext ()).build ();
388
390
}
389
391
@@ -683,23 +685,73 @@ public void setTruncate(Boolean truncate) {
683
685
}
684
686
685
687
@ Override
688
+ @ JsonIgnore
689
+ public List <FunctionCallback > getToolCallbacks () {
690
+ return this .toolCallbacks ;
691
+ }
692
+
693
+ @ Override
694
+ @ JsonIgnore
695
+ public void setToolCallbacks (List <FunctionCallback > toolCallbacks ) {
696
+ Assert .notNull (toolCallbacks , "toolCallbacks cannot be null" );
697
+ Assert .noNullElements (toolCallbacks , "toolCallbacks cannot contain null elements" );
698
+ this .toolCallbacks = toolCallbacks ;
699
+ }
700
+
701
+ @ Override
702
+ @ JsonIgnore
703
+ public Set <String > getTools () {
704
+ return this .toolNames ;
705
+ }
706
+
707
+ @ Override
708
+ @ JsonIgnore
709
+ public void setTools (Set <String > toolNames ) {
710
+ Assert .notNull (toolNames , "toolNames cannot be null" );
711
+ Assert .noNullElements (toolNames , "toolNames cannot contain null elements" );
712
+ toolNames .forEach (tool -> Assert .hasText (tool , "toolNames cannot contain empty elements" ));
713
+ this .toolNames = toolNames ;
714
+ }
715
+
716
+ @ Override
717
+ @ Nullable
718
+ @ JsonIgnore
719
+ public Boolean isInternalToolExecutionEnabled () {
720
+ return internalToolExecutionEnabled ;
721
+ }
722
+
723
+ @ Override
724
+ @ JsonIgnore
725
+ public void setInternalToolExecutionEnabled (@ Nullable Boolean internalToolExecutionEnabled ) {
726
+ this .internalToolExecutionEnabled = internalToolExecutionEnabled ;
727
+ }
728
+
729
+ @ Override
730
+ @ Deprecated
731
+ @ JsonIgnore
686
732
public List <FunctionCallback > getFunctionCallbacks () {
687
- return this .functionCallbacks ;
733
+ return this .getToolCallbacks () ;
688
734
}
689
735
690
736
@ Override
737
+ @ Deprecated
738
+ @ JsonIgnore
691
739
public void setFunctionCallbacks (List <FunctionCallback > functionCallbacks ) {
692
- this .functionCallbacks = functionCallbacks ;
740
+ this .setToolCallbacks ( functionCallbacks ) ;
693
741
}
694
742
695
743
@ Override
744
+ @ Deprecated
745
+ @ JsonIgnore
696
746
public Set <String > getFunctions () {
697
- return this .functions ;
747
+ return this .getTools () ;
698
748
}
699
749
700
750
@ Override
751
+ @ Deprecated
752
+ @ JsonIgnore
701
753
public void setFunctions (Set <String > functions ) {
702
- this .functions = functions ;
754
+ this .setTools ( functions ) ;
703
755
}
704
756
705
757
@ Override
@@ -709,20 +761,26 @@ public Integer getDimensions() {
709
761
}
710
762
711
763
@ Override
764
+ @ Deprecated
765
+ @ JsonIgnore
712
766
public Boolean getProxyToolCalls () {
713
- return this .proxyToolCalls ;
767
+ return this .internalToolExecutionEnabled != null ? ! this . internalToolExecutionEnabled : null ;
714
768
}
715
769
770
+ @ Deprecated
771
+ @ JsonIgnore
716
772
public void setProxyToolCalls (Boolean proxyToolCalls ) {
717
- this .proxyToolCalls = proxyToolCalls ;
773
+ this .internalToolExecutionEnabled = proxyToolCalls != null ? ! proxyToolCalls : null ;
718
774
}
719
775
720
776
@ Override
777
+ @ JsonIgnore
721
778
public Map <String , Object > getToolContext () {
722
779
return this .toolContext ;
723
780
}
724
781
725
782
@ Override
783
+ @ JsonIgnore
726
784
public void setToolContext (Map <String , Object > toolContext ) {
727
785
this .toolContext = toolContext ;
728
786
}
@@ -769,9 +827,9 @@ public boolean equals(Object o) {
769
827
&& Objects .equals (this .mirostat , that .mirostat ) && Objects .equals (this .mirostatTau , that .mirostatTau )
770
828
&& Objects .equals (this .mirostatEta , that .mirostatEta )
771
829
&& Objects .equals (this .penalizeNewline , that .penalizeNewline ) && Objects .equals (this .stop , that .stop )
772
- && Objects .equals (this .functionCallbacks , that .functionCallbacks )
773
- && Objects .equals (this .proxyToolCalls , that .proxyToolCalls )
774
- && Objects .equals (this .functions , that .functions ) && Objects .equals (this .toolContext , that .toolContext );
830
+ && Objects .equals (this .toolCallbacks , that .toolCallbacks )
831
+ && Objects .equals (this .internalToolExecutionEnabled , that .internalToolExecutionEnabled )
832
+ && Objects .equals (this .toolNames , that .toolNames ) && Objects .equals (this .toolContext , that .toolContext );
775
833
}
776
834
777
835
@ Override
@@ -781,7 +839,7 @@ public int hashCode() {
781
839
this .useMMap , this .useMLock , this .numThread , this .numKeep , this .seed , this .numPredict , this .topK ,
782
840
this .topP , this .tfsZ , this .typicalP , this .repeatLastN , this .temperature , this .repeatPenalty ,
783
841
this .presencePenalty , this .frequencyPenalty , this .mirostat , this .mirostatTau , this .mirostatEta ,
784
- this .penalizeNewline , this .stop , this .functionCallbacks , this .functions , this .proxyToolCalls ,
842
+ this .penalizeNewline , this .stop , this .toolCallbacks , this .toolNames , this .internalToolExecutionEnabled ,
785
843
this .toolContext );
786
844
}
787
845
@@ -959,25 +1017,53 @@ public Builder stop(List<String> stop) {
959
1017
return this ;
960
1018
}
961
1019
962
- public Builder functionCallbacks (List <FunctionCallback > functionCallbacks ) {
963
- this .options .functionCallbacks = functionCallbacks ;
1020
+ public Builder toolCallbacks (List <FunctionCallback > toolCallbacks ) {
1021
+ this .options .setToolCallbacks ( toolCallbacks ) ;
964
1022
return this ;
965
1023
}
966
1024
967
- public Builder functions ( Set < String > functions ) {
968
- Assert .notNull (functions , "Function names must not be null" );
969
- this .options .functions = functions ;
1025
+ public Builder toolCallbacks ( FunctionCallback ... toolCallbacks ) {
1026
+ Assert .notNull (toolCallbacks , "toolCallbacks cannot be null" );
1027
+ this .options .toolCallbacks . addAll ( Arrays . asList ( toolCallbacks )) ;
970
1028
return this ;
971
1029
}
972
1030
973
- public Builder function (String functionName ) {
974
- Assert .hasText (functionName , "Function name must not be empty" );
975
- this .options .functions .add (functionName );
1031
+ public Builder tools (Set <String > toolNames ) {
1032
+ this .options .setTools (toolNames );
1033
+ return this ;
1034
+ }
1035
+
1036
+ public Builder tools (String ... toolNames ) {
1037
+ Assert .notNull (toolNames , "toolNames cannot be null" );
1038
+ this .options .toolNames .addAll (Set .of (toolNames ));
1039
+ return this ;
1040
+ }
1041
+
1042
+ public Builder internalToolExecutionEnabled (@ Nullable Boolean internalToolExecutionEnabled ) {
1043
+ this .options .setInternalToolExecutionEnabled (internalToolExecutionEnabled );
976
1044
return this ;
977
1045
}
978
1046
1047
+ @ Deprecated
1048
+ public Builder functionCallbacks (List <FunctionCallback > functionCallbacks ) {
1049
+ return toolCallbacks (functionCallbacks );
1050
+ }
1051
+
1052
+ @ Deprecated
1053
+ public Builder functions (Set <String > functions ) {
1054
+ return tools (functions );
1055
+ }
1056
+
1057
+ @ Deprecated
1058
+ public Builder function (String functionName ) {
1059
+ return tools (functionName );
1060
+ }
1061
+
1062
+ @ Deprecated
979
1063
public Builder proxyToolCalls (Boolean proxyToolCalls ) {
980
- this .options .proxyToolCalls = proxyToolCalls ;
1064
+ if (proxyToolCalls != null ) {
1065
+ this .options .setInternalToolExecutionEnabled (!proxyToolCalls );
1066
+ }
981
1067
return this ;
982
1068
}
983
1069
0 commit comments