1515 */
1616package org .springframework .ai .openai ;
1717
18- import java .util .ArrayList ;
19- import java .util .Base64 ;
20- import java .util .HashMap ;
21- import java .util .HashSet ;
22- import java .util .List ;
23- import java .util .Map ;
24- import java .util .Optional ;
25- import java .util .Set ;
26- import java .util .concurrent .ConcurrentHashMap ;
27-
2818import org .slf4j .Logger ;
2919import org .slf4j .LoggerFactory ;
20+ import org .springframework .ai .chat .messages .AssistantMessage ;
21+ import org .springframework .ai .chat .messages .Message ;
22+ import org .springframework .ai .chat .messages .MessageType ;
23+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
3024import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
3125import org .springframework .ai .chat .metadata .RateLimit ;
3226import org .springframework .ai .chat .model .ChatModel ;
3630import org .springframework .ai .chat .prompt .ChatOptions ;
3731import org .springframework .ai .chat .prompt .Prompt ;
3832import org .springframework .ai .model .ModelOptionsUtils ;
39- import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
33+ import org .springframework .ai .model .function .AbstractToolCallSupport ;
4034import org .springframework .ai .model .function .FunctionCallbackContext ;
4135import org .springframework .ai .openai .api .OpenAiApi ;
4236import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion ;
4337import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion .Choice ;
4438import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionFinishReason ;
4539import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage ;
40+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .ChatCompletionFunction ;
4641import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .MediaContent ;
47- import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .Role ;
4842import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage .ToolCall ;
4943import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest ;
5044import org .springframework .ai .openai .metadata .OpenAiChatResponseMetadata ;
5549import org .springframework .util .Assert ;
5650import org .springframework .util .CollectionUtils ;
5751import org .springframework .util .MimeType ;
58-
5952import reactor .core .publisher .Flux ;
53+ import reactor .core .publisher .Mono ;
54+
55+ import java .util .ArrayList ;
56+ import java .util .Base64 ;
57+ import java .util .HashSet ;
58+ import java .util .List ;
59+ import java .util .Map ;
60+ import java .util .Set ;
61+ import java .util .concurrent .ConcurrentHashMap ;
6062
6163/**
6264 * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
7779 * @see StreamingChatModel
7880 * @see OpenAiApi
7981 */
80- public class OpenAiChatModel extends
81- AbstractFunctionCallSupport <ChatCompletionMessage , OpenAiApi .ChatCompletionRequest , ResponseEntity <ChatCompletion >>
82- implements ChatModel {
82+ public class OpenAiChatModel extends AbstractToolCallSupport <ChatCompletion > implements ChatModel {
8383
8484 private static final Logger logger = LoggerFactory .getLogger (OpenAiChatModel .class );
8585
@@ -145,14 +145,25 @@ public ChatResponse call(Prompt prompt) {
145145
146146 return this .retryTemplate .execute (ctx -> {
147147
148- ResponseEntity <ChatCompletion > completionEntity = this .callWithFunctionSupport (request );
148+ ResponseEntity <ChatCompletion > completionEntity = this .openAiApi . chatCompletionEntity (request );
149149
150150 var chatCompletion = completionEntity .getBody ();
151+
151152 if (chatCompletion == null ) {
152153 logger .warn ("No chat completion returned for prompt: {}" , prompt );
153154 return new ChatResponse (List .of ());
154155 }
155156
157+ if (isToolFunctionCall (chatCompletion )) {
158+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
159+ chatCompletion );
160+ // Recursively call the call method with the tool call message
161+ // conversation that contains the call responses.
162+
163+ return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
164+ }
165+
166+ // Non function calling.
156167 RateLimit rateLimits = OpenAiResponseHeaderExtractor .extractAiResponseHeaders (completionEntity );
157168
158169 List <Choice > choices = chatCompletion .choices ();
@@ -162,7 +173,10 @@ public ChatResponse call(Prompt prompt) {
162173 }
163174
164175 List <Generation > generations = choices .stream ().map (choice -> {
165- var generation = new Generation (choice .message ().content (), toMap (chatCompletion .id (), choice ));
176+ Map <String , Object > metadata = Map .of ("id" , chatCompletion .id (), "role" ,
177+ choice .message ().role () != null ? choice .message ().role ().name () : "" , "finishReason" ,
178+ choice .finishReason () != null ? choice .finishReason ().name () : "" );
179+ var generation = new Generation (choice .message ().content (), metadata );
166180 if (choice .finishReason () != null ) {
167181 generation = generation
168182 .withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null ));
@@ -176,20 +190,6 @@ public ChatResponse call(Prompt prompt) {
176190 });
177191 }
178192
179- private Map <String , Object > toMap (String id , ChatCompletion .Choice choice ) {
180- Map <String , Object > map = new HashMap <>();
181-
182- var message = choice .message ();
183- if (message .role () != null ) {
184- map .put ("role" , message .role ().name ());
185- }
186- if (choice .finishReason () != null ) {
187- map .put ("finishReason" , choice .finishReason ().name ());
188- }
189- map .put ("id" , id );
190- return map ;
191- }
192-
193193 @ Override
194194 public Flux <ChatResponse > stream (Prompt prompt ) {
195195
@@ -205,16 +205,23 @@ public Flux<ChatResponse> stream(Prompt prompt) {
205205
206206 // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
207207 // the function call handling logic.
208- return completionChunks .map (chunk -> chunkToChatCompletion (chunk ))
209- .switchMap (
210- cc -> handleFunctionCallOrReturnStream (request , Flux .just (ResponseEntity .of (Optional .of (cc )))))
211- .map (ResponseEntity ::getBody )
212- .map (chatCompletion -> {
208+ return completionChunks .map (this ::chunkToChatCompletion ).switchMap (chatCompletion -> {
209+
210+ if (this .isToolFunctionCall (chatCompletion )) {
211+ var toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
212+ chatCompletion );
213+ // Recursively call the stream method with the tool call message
214+ // conversation that contains the call responses.
215+ return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
216+ }
217+
218+ // Non function calling.
219+ return Mono .just (chatCompletion ).map (chatCompletion2 -> {
213220 try {
214221 @ SuppressWarnings ("null" )
215- String id = chatCompletion .id ();
222+ String id = chatCompletion2 .id ();
216223
217- List <Generation > generations = chatCompletion .choices ().stream ().map (choice -> {
224+ List <Generation > generations = chatCompletion2 .choices ().stream ().map (choice -> {
218225 if (choice .message ().role () != null ) {
219226 roleMap .putIfAbsent (id , choice .message ().role ().name ());
220227 }
@@ -228,8 +235,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
228235 return generation ;
229236 }).toList ();
230237
231- if (chatCompletion .usage () != null ) {
232- return new ChatResponse (generations , OpenAiChatResponseMetadata .from (chatCompletion ));
238+ if (chatCompletion2 .usage () != null ) {
239+ return new ChatResponse (generations , OpenAiChatResponseMetadata .from (chatCompletion2 ));
233240 }
234241 else {
235242 return new ChatResponse (generations );
@@ -241,9 +248,33 @@ public Flux<ChatResponse> stream(Prompt prompt) {
241248 }
242249
243250 });
251+ });
244252 });
245253 }
246254
255+ private List <Message > handleToolCallRequests (List <Message > previousMessages , ChatCompletion chatCompletion ) {
256+
257+ ChatCompletionMessage nativeAssistantMessage = this .extractAssistantMessage (chatCompletion );
258+
259+ List <AssistantMessage .ToolCall > assistantToolCalls = nativeAssistantMessage .toolCalls ()
260+ .stream ()
261+ .map (toolCall -> new AssistantMessage .ToolCall (toolCall .id (), "function" , toolCall .function ().name (),
262+ toolCall .function ().arguments ()))
263+ .toList ();
264+
265+ AssistantMessage assistantMessage = new AssistantMessage (nativeAssistantMessage .content (), Map .of (),
266+ assistantToolCalls );
267+
268+ List <ToolResponseMessage > toolResponseMessages = this .executeFuncitons (assistantMessage );
269+
270+ // History
271+ List <Message > messages = new ArrayList <>(previousMessages );
272+ messages .add (assistantMessage );
273+ messages .addAll (toolResponseMessages );
274+
275+ return messages ;
276+ }
277+
247278 /**
248279 * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
249280 * @param chunk the ChatCompletionChunk to convert
@@ -252,38 +283,66 @@ public Flux<ChatResponse> stream(Prompt prompt) {
252283 private OpenAiApi .ChatCompletion chunkToChatCompletion (OpenAiApi .ChatCompletionChunk chunk ) {
253284 List <Choice > choices = chunk .choices ()
254285 .stream ()
255- .map (cc -> new Choice (cc .finishReason (), cc .index (), cc .delta (), cc .logprobs ()))
286+ .map (chunkChoice -> new Choice (chunkChoice .finishReason (), chunkChoice .index (), chunkChoice .delta (),
287+ chunkChoice .logprobs ()))
256288 .toList ();
257289
258290 return new OpenAiApi .ChatCompletion (chunk .id (), choices , chunk .created (), chunk .model (),
259291 chunk .systemFingerprint (), "chat.completion" , chunk .usage ());
260292 }
261293
294+ private ChatCompletionMessage extractAssistantMessage (ChatCompletion chatCompletion ) {
295+ return chatCompletion .choices ().iterator ().next ().message ();
296+ }
297+
262298 /**
263299 * Accessible for testing.
264300 */
265301 ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
266302
267303 Set <String > functionsForThisRequest = new HashSet <>();
268304
269- List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ().stream ().map (m -> {
270- Object content ;
271- if (CollectionUtils .isEmpty (m .getMedia ())) {
272- content = m .getContent ();
273- }
274- else {
275- List <MediaContent > contentList = new ArrayList <>(List .of (new MediaContent (m .getContent ())));
305+ List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ().stream ().map (message -> {
306+ if (message .getMessageType () == MessageType .USER || message .getMessageType () == MessageType .SYSTEM ) {
307+ Object content ;
308+ if (CollectionUtils .isEmpty (message .getMedia ())) {
309+ content = message .getContent ();
310+ }
311+ else {
312+ List <MediaContent > contentList = new ArrayList <>(List .of (new MediaContent (message .getContent ())));
276313
277- contentList .addAll (m .getMedia ()
278- .stream ()
279- .map (media -> new MediaContent (
280- new MediaContent .ImageUrl (this .fromMediaData (media .getMimeType (), media .getData ()))))
281- .toList ());
314+ contentList .addAll (message .getMedia ()
315+ .stream ()
316+ .map (media -> new MediaContent (
317+ new MediaContent .ImageUrl (this .fromMediaData (media .getMimeType (), media .getData ()))))
318+ .toList ());
282319
283- content = contentList ;
284- }
320+ content = contentList ;
321+ }
285322
286- return new ChatCompletionMessage (content , ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ()));
323+ return new ChatCompletionMessage (content ,
324+ ChatCompletionMessage .Role .valueOf (message .getMessageType ().name ()));
325+ }
326+ else if (message .getMessageType () == MessageType .ASSISTANT ) {
327+ var assistantMessage = (AssistantMessage ) message ;
328+ List <ToolCall > toolCalls = null ;
329+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
330+ toolCalls = assistantMessage .getToolCalls ().stream ().map (toolCall -> {
331+ var function = new ChatCompletionFunction (toolCall .name (), toolCall .arguments ());
332+ return new ToolCall (toolCall .id (), toolCall .type (), function );
333+ }).toList ();
334+ }
335+ return new ChatCompletionMessage (assistantMessage .getContent (), ChatCompletionMessage .Role .ASSISTANT ,
336+ null , null , toolCalls );
337+ }
338+ else if (message .getMessageType () == MessageType .TOOL ) {
339+ ToolResponseMessage toolMessage = (ToolResponseMessage ) message ;
340+ return new ChatCompletionMessage (toolMessage .getContent (), ChatCompletionMessage .Role .TOOL ,
341+ toolMessage .getName (), toolMessage .getId (), null );
342+ }
343+ else {
344+ throw new IllegalArgumentException ("Unsupported message type: " + message .getMessageType ());
345+ }
287346 }).toList ();
288347
289348 ChatCompletionRequest request = new ChatCompletionRequest (chatCompletionMessages , stream );
@@ -351,66 +410,12 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
351410 }
352411
353412 @ Override
354- protected ChatCompletionRequest doCreateToolResponseRequest (ChatCompletionRequest previousRequest ,
355- ChatCompletionMessage responseMessage , List <ChatCompletionMessage > conversationHistory ) {
356-
357- // Every tool-call item requires a separate function call and a response (TOOL)
358- // message.
359- for (ToolCall toolCall : responseMessage .toolCalls ()) {
360-
361- var functionName = toolCall .function ().name ();
362- String functionArguments = toolCall .function ().arguments ();
363-
364- if (!this .functionCallbackRegister .containsKey (functionName )) {
365- throw new IllegalStateException ("No function callback found for function name: " + functionName );
366- }
367-
368- String functionResponse = this .functionCallbackRegister .get (functionName ).call (functionArguments );
369-
370- // Add the function response to the conversation.
371- conversationHistory
372- .add (new ChatCompletionMessage (functionResponse , Role .TOOL , functionName , toolCall .id (), null ));
373- }
374-
375- // Recursively call chatCompletionWithTools until the model doesn't call a
376- // functions anymore.
377- ChatCompletionRequest newRequest = new ChatCompletionRequest (conversationHistory , previousRequest .stream ());
378- newRequest = ModelOptionsUtils .merge (newRequest , previousRequest , ChatCompletionRequest .class );
379-
380- return newRequest ;
381- }
382-
383- @ Override
384- protected List <ChatCompletionMessage > doGetUserMessages (ChatCompletionRequest request ) {
385- return request .messages ();
386- }
387-
388- @ Override
389- protected ChatCompletionMessage doGetToolResponseMessage (ResponseEntity <ChatCompletion > chatCompletion ) {
390- return chatCompletion .getBody ().choices ().iterator ().next ().message ();
391- }
392-
393- @ Override
394- protected ResponseEntity <ChatCompletion > doChatCompletion (ChatCompletionRequest request ) {
395- return this .openAiApi .chatCompletionEntity (request );
396- }
397-
398- @ Override
399- protected Flux <ResponseEntity <ChatCompletion >> doChatCompletionStream (ChatCompletionRequest request ) {
400- return this .openAiApi .chatCompletionStream (request )
401- .map (this ::chunkToChatCompletion )
402- .map (Optional ::ofNullable )
403- .map (ResponseEntity ::of );
404- }
405-
406- @ Override
407- protected boolean isToolFunctionCall (ResponseEntity <ChatCompletion > chatCompletion ) {
408- var body = chatCompletion .getBody ();
409- if (body == null ) {
413+ protected boolean isToolFunctionCall (ChatCompletion chatCompletion ) {
414+ if (chatCompletion == null ) {
410415 return false ;
411416 }
412417
413- var choices = body .choices ();
418+ var choices = chatCompletion .choices ();
414419 if (CollectionUtils .isEmpty (choices )) {
415420 return false ;
416421 }
0 commit comments