package org.springframework.ai.openai;

import java.time.Duration;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.model.ToolFunctionCallback;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;

/**
 * {@link ChatClient} implementation for {@literal OpenAI} backed by {@link OpenAiApi}.
@@ -66,11 +74,14 @@ public class OpenAiChatClient implements ChatClient, StreamingChatClient {

	private OpenAiChatOptions defaultOptions;

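+	/**
+	 * Register of the tool function callbacks, keyed by function name.
+	 */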
+	private Map<String, ToolFunctionCallback> toolCallbackRegister = new ConcurrentHashMap<>();
+
	public final RetryTemplate retryTemplate = RetryTemplate.builder()
		.maxAttempts(10)
		.retryOn(OpenAiApiException.class)
		.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
		.withListener(new RetryListener() {
+			@Override
			public <T extends Object, E extends Throwable> void onError(RetryContext context,
					RetryCallback<T, E> callback, Throwable throwable) {
				logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
@@ -108,18 +119,18 @@ public ChatResponse call(Prompt prompt) {

		ChatCompletionRequest request = createRequest(prompt, false);

-		ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
+		ResponseEntity<ChatCompletion> completionEntity = this.chatCompletionWithTools(request);

		var chatCompletion = completionEntity.getBody();
		if (chatCompletion == null) {
-			logger.warn("No chat completion returned for request: {}", prompt);
+			logger.warn("No chat completion returned for prompt: {}", prompt);
			return new ChatResponse(List.of());
		}

		RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);

		List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
-			return new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
+			return new Generation(choice.message().content(), toMap(choice.message()))
				.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
		}).toList();

@@ -162,6 +173,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
	 */
	ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

+		Set<String> enabledFunctionsForRequest = new HashSet<>();
+
		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
			.stream()
			.map(m -> new ChatCompletionMessage(m.getContent(),
@@ -170,14 +183,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

		ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);

-		if (this.defaultOptions != null) {
-			request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
-		}
-
		if (prompt.getOptions() != null) {
			if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
				OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
						ChatOptions.class, OpenAiChatOptions.class);
+
+				Set<String> promptEnabledFunctions = handleToolFunctionConfigurations(updatedRuntimeOptions, true,
+						true);
+				enabledFunctionsForRequest.addAll(promptEnabledFunctions);
+
				request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
			}
			else {
@@ -186,7 +200,180 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
			}
		}

+		if (this.defaultOptions != null) {
+
+			Set<String> defaultEnabledFunctions = handleToolFunctionConfigurations(this.defaultOptions, false, false);
+
+			enabledFunctionsForRequest.addAll(defaultEnabledFunctions);
+
+			request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
+		}
+
+		// Add the enabled function definitions to the request's tools parameter.
+		if (!CollectionUtils.isEmpty(enabledFunctionsForRequest)) {
+
+			if (stream) {
+				throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
+			}
+
+			request = ModelOptionsUtils.merge(
+					OpenAiChatOptions.builder().withTools(this.getFunctionTools(enabledFunctionsForRequest)).build(),
+					request, ChatCompletionRequest.class);
+		}
+
		return request;
	}

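+	/**
+	 * Registers the tool callbacks carried by the given options and resolves the set of
+	 * function names that should be enabled for the current request.
+	 * @param options the chat options, possibly carrying tool callbacks and enabled function names.
+	 * @param autoEnableCallbackFunctions whether every registered callback is also enabled automatically.
+	 * @param overrideCallbackFunctionsRegister whether an already registered callback with the same name is replaced.
+	 * @return the names of the functions to enable for this request.
+	 */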
+	private Set<String> handleToolFunctionConfigurations(OpenAiChatOptions options, boolean autoEnableCallbackFunctions,
+			boolean overrideCallbackFunctionsRegister) {
+
+		Set<String> enabledFunctions = new HashSet<>();
+
+		if (options != null) {
+			if (!CollectionUtils.isEmpty(options.getToolCallbacks())) {
+				options.getToolCallbacks().stream().forEach(toolCallback -> {
+
+					// Register the tool callback.
+					if (overrideCallbackFunctionsRegister) {
+						this.toolCallbackRegister.put(toolCallback.getName(), toolCallback);
+					}
+					else {
+						this.toolCallbackRegister.putIfAbsent(toolCallback.getName(), toolCallback);
+					}
+
+					// Automatically enable the function, typically for callbacks provided via the prompt options.
+					if (autoEnableCallbackFunctions) {
+						enabledFunctions.add(toolCallback.getName());
+					}
+				});
+			}
+
+			// Add the explicitly enabled functions.
+			if (!CollectionUtils.isEmpty(options.getEnabledFunctions())) {
+				enabledFunctions.addAll(options.getEnabledFunctions());
+			}
+		}
+
+		return enabledFunctions;
+	}
+
+	/**
+	 * @return the registered tool callbacks, keyed by function name.
+	 */
+	Map<String, ToolFunctionCallback> getToolCallbackRegister() {
+		return toolCallbackRegister;
+	}
+
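+	/**
+	 * Resolves the given function names against the callback register and builds the
+	 * corresponding OpenAI function tool definitions.
+	 * @param functionNames the names of the enabled functions.
+	 * @return the function tool definitions to include in the request.
+	 * @throws IllegalStateException if a function name has no registered callback.
+	 */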
+	public List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
+
+		List<OpenAiApi.FunctionTool> functionTools = new ArrayList<>();
+		for (String functionName : functionNames) {
+			if (!this.toolCallbackRegister.containsKey(functionName)) {
+				throw new IllegalStateException("No function callback found for function name: " + functionName);
+			}
+			ToolFunctionCallback functionCallback = this.toolCallbackRegister.get(functionName);
+
+			var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(),
+					functionCallback.getName(), functionCallback.getInputTypeSchema());
+			functionTools.add(new OpenAiApi.FunctionTool(function));
+		}
+
+		return functionTools;
+	}
+
+	/**
+	 * Function call handling. If the model requests a function call, the registered
+	 * callback is invoked and its response is appended to the conversation history,
+	 * which is then sent back to the model.
+	 * @param request the chat completion request.
+	 * @return the chat completion response.
+	 */
+	@SuppressWarnings("null")
+	private ResponseEntity<ChatCompletion> chatCompletionWithTools(OpenAiApi.ChatCompletionRequest request) {
+
+		ResponseEntity<ChatCompletion> chatCompletion = this.openAiApi.chatCompletionEntity(request);
+
+		// Return the result if the model is not calling a function.
+		if (Boolean.FALSE.equals(this.isToolCall(chatCompletion))) {
+			return chatCompletion;
+		}
+
+		// The OpenAI chat completion tool call API requires the complete conversation
+		// history, including the initial user message.
+		List<ChatCompletionMessage> conversationMessages = new ArrayList<>(request.messages());
+
+		// We assume that the tool calling information is inside the response's first
+		// choice.
+		ChatCompletionMessage responseMessage = chatCompletion.getBody().choices().iterator().next().message();
+
+		if (chatCompletion.getBody().choices().size() > 1) {
+			logger.warn("More than one choice returned. Only the first choice is processed.");
+		}
+
+		// Add the assistant response to the message conversation history.
+		conversationMessages.add(responseMessage);
+
+		// Every tool-call item requires a separate function call and a response (TOOL)
+		// message.
+		for (ToolCall toolCall : responseMessage.toolCalls()) {
+
+			var functionName = toolCall.function().name();
+			String functionArguments = toolCall.function().arguments();
+
+			if (!this.toolCallbackRegister.containsKey(functionName)) {
+				throw new IllegalStateException("No function callback found for function name: " + functionName);
+			}
+
+			String functionResponse = this.toolCallbackRegister.get(functionName).call(functionArguments);
+
+			// Add the function response to the conversation.
+			conversationMessages.add(new ChatCompletionMessage(functionResponse, Role.TOOL, null, toolCall.id(), null));
+		}
+
+		// Recursively call chatCompletionWithTools until the model no longer requests
+		// a function call.
+		ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationMessages, request.stream());
+		newRequest = ModelOptionsUtils.merge(newRequest, request, ChatCompletionRequest.class);
+
+		return this.chatCompletionWithTools(newRequest);
+	}
+
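+	/**
+	 * Converts the chat completion message into the property map attached to the
+	 * resulting {@link Generation}, exposing the role and, when present, the
+	 * tool_calls and tool_call_id fields.
+	 */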
+	private Map<String, Object> toMap(ChatCompletionMessage message) {
+		Map<String, Object> map = new HashMap<>();
+
+		// The tool_calls and tool_call_id entries are not used by the OpenAiChatClient
+		// function calling support. They are exposed only for users who want to access
+		// the tool_calls and tool_call_id in their applications.
+		if (message.toolCalls() != null) {
+			map.put("tool_calls", message.toolCalls());
+		}
+		if (message.toolCallId() != null) {
+			map.put("tool_call_id", message.toolCallId());
+		}
+
+		if (message.role() != null) {
+			map.put("role", message.role().name());
+		}
+		return map;
+	}
+
+	/**
+	 * Checks whether the chat completion response is a model function-call request.
+	 * @param chatCompletion the chat completion response.
+	 * @return true if the model expects a function call.
+	 */
+	private Boolean isToolCall(ResponseEntity<ChatCompletion> chatCompletion) {
+		var body = chatCompletion.getBody();
+		if (body == null) {
+			return false;
+		}
+
+		var choices = body.choices();
+		if (CollectionUtils.isEmpty(choices)) {
+			return false;
+		}
+
+		return choices.get(0).message().toolCalls() != null;
+	}
+
}