4242import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
4343import org .springframework .ai .chat .metadata .EmptyUsage ;
4444import org .springframework .ai .chat .metadata .RateLimit ;
45+ import org .springframework .ai .chat .metadata .Usage ;
46+ import org .springframework .ai .chat .metadata .UsageUtils ;
4547import org .springframework .ai .chat .model .AbstractToolCallSupport ;
4648import org .springframework .ai .chat .model .ChatModel ;
4749import org .springframework .ai .chat .model .ChatResponse ;
99101 * @author Mariusz Bernacki
100102 * @author luocongqiu
101103 * @author Thomas Vitale
104+ * @author Ilayaperumal Gopinathan
102105 * @see ChatModel
103106 * @see StreamingChatModel
104107 * @see OpenAiApi
@@ -215,6 +218,10 @@ public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
215218
/**
 * Entry point for a non-streaming chat call. Delegates to
 * {@code internalCall}, passing {@code null} as the previous chat response
 * because a top-level call has no earlier response to accumulate usage from.
 * @param prompt the prompt to send
 * @return the chat response produced by {@code internalCall}
 */
216219 @ Override
217220 public ChatResponse call (Prompt prompt ) {
221+ return this .internalCall (prompt , null );
222+ }
223+
224+ public ChatResponse internalCall (Prompt prompt , ChatResponse previousChatResponse ) {
218225
219226 ChatCompletionRequest request = createRequest (prompt , false );
220227
@@ -259,8 +266,12 @@ public ChatResponse call(Prompt prompt) {
259266
260267 // Non function calling.
261268 RateLimit rateLimit = OpenAiResponseHeaderExtractor .extractAiResponseHeaders (completionEntity );
262-
263- ChatResponse chatResponse = new ChatResponse (generations , from (completionEntity .getBody (), rateLimit ));
269+ // Current usage
270+ OpenAiApi .Usage usage = completionEntity .getBody ().usage ();
271+ Usage currentChatResponseUsage = usage != null ? OpenAiUsage .from (usage ) : new EmptyUsage ();
272+ Usage accumulatedUsage = UsageUtils .getCumulativeUsage (currentChatResponseUsage , previousChatResponse );
273+ ChatResponse chatResponse = new ChatResponse (generations ,
274+ from (completionEntity .getBody (), rateLimit , accumulatedUsage ));
264275
265276 observationContext .setResponse (chatResponse );
266277
@@ -274,14 +285,18 @@ && isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.n
274285 var toolCallConversation = handleToolCalls (prompt , response );
275286 // Recursively call internalCall with the tool call message conversation
276287 // that contains the call responses, passing the current response along
277- return this .call (new Prompt (toolCallConversation , prompt .getOptions ()));
288+ return this .internalCall (new Prompt (toolCallConversation , prompt .getOptions ()), response );
278289 }
279290
280291 return response ;
281292 }
282293
/**
 * Entry point for a streaming chat call. Delegates to
 * {@code internalStream}, passing {@code null} as the previous chat response
 * because a top-level call has no earlier response to accumulate usage from.
 * @param prompt the prompt to send
 * @return the stream of chat responses produced by {@code internalStream}
 */
283294 @ Override
284295 public Flux <ChatResponse > stream (Prompt prompt ) {
296+ return internalStream (prompt , null );
297+ }
298+
299+ public Flux <ChatResponse > internalStream (Prompt prompt , ChatResponse previousChatResponse ) {
285300 return Flux .deferContextual (contextView -> {
286301 ChatCompletionRequest request = createRequest (prompt , true );
287302
@@ -337,15 +352,43 @@ public Flux<ChatResponse> stream(Prompt prompt) {
337352 return buildGeneration (choice , metadata , request );
338353 }).toList ();
339354 // @formatter:on
340-
341- return new ChatResponse (generations , from (chatCompletion2 , null ));
355+ OpenAiApi .Usage usage = chatCompletion2 .usage ();
356+ Usage currentChatResponseUsage = usage != null ? OpenAiUsage .from (usage ) : new EmptyUsage ();
357+ Usage accumulatedUsage = UsageUtils .getCumulativeUsage (currentChatResponseUsage ,
358+ previousChatResponse );
359+ return new ChatResponse (generations , from (chatCompletion2 , null , accumulatedUsage ));
342360 }
343361 catch (Exception e ) {
344362 logger .error ("Error processing chat completion" , e );
345363 return new ChatResponse (List .of ());
346364 }
347-
348- }));
365+ // When streaming with usage reporting enabled, OpenAI sets the usage
366+ // only on the final chat completion response. The overlapping buffer
367+ // below therefore pairs each response with the one that follows it,
368+ // so that the usage carried by the following response can be copied
369+ // onto the current one for accumulation.
370+ }))
371+ .buffer (2 , 1 )
372+ .map (bufferList -> {
373+ ChatResponse firstResponse = bufferList .get (0 );
374+ if (request .streamOptions () != null && request .streamOptions ().includeUsage ()) {
375+ if (bufferList .size () == 2 ) {
376+ ChatResponse secondResponse = bufferList .get (1 );
377+ if (secondResponse != null && secondResponse .getMetadata () != null ) {
378+ // This is the usage from the final Chat response for a
379+ // given Chat request.
380+ Usage usage = secondResponse .getMetadata ().getUsage ();
381+ if (!UsageUtils .isEmpty (usage )) {
382+ // Store the usage from the final response to the
383+ // penultimate response for accumulation.
384+ return new ChatResponse (firstResponse .getResults (),
385+ from (firstResponse .getMetadata (), usage ));
386+ }
387+ }
388+ }
389+ }
390+ return firstResponse ;
391+ });
349392
350393 // @formatter:off
351394 Flux <ChatResponse > flux = chatResponse .flatMap (response -> {
@@ -355,7 +398,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
355398 var toolCallConversation = handleToolCalls (prompt , response );
356399 // Recursively call internalStream with the tool call message conversation
357400 // that contains the call responses, passing the current response along
358- return this .stream (new Prompt (toolCallConversation , prompt .getOptions ()));
401+ return this .internalStream (new Prompt (toolCallConversation , prompt .getOptions ()), response );
359402 }
360403 else {
361404 return Flux .just (response );
@@ -412,11 +455,11 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
412455 return new Generation (assistantMessage , generationMetadataBuilder .build ());
413456 }
414457
415- private ChatResponseMetadata from (OpenAiApi .ChatCompletion result , RateLimit rateLimit ) {
458+ private ChatResponseMetadata from (OpenAiApi .ChatCompletion result , RateLimit rateLimit , Usage usage ) {
416459 Assert .notNull (result , "OpenAI ChatCompletionResult must not be null" );
417460 var builder = ChatResponseMetadata .builder ()
418461 .withId (result .id () != null ? result .id () : "" )
419- .withUsage (result . usage () != null ? OpenAiUsage . from ( result . usage ()) : new EmptyUsage () )
462+ .withUsage (usage )
420463 .withModel (result .model () != null ? result .model () : "" )
421464 .withKeyValue ("created" , result .created () != null ? result .created () : 0L )
422465 .withKeyValue ("system-fingerprint" , result .systemFingerprint () != null ? result .systemFingerprint () : "" );
@@ -426,6 +469,18 @@ private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rat
426469 return builder .build ();
427470 }
428471
/**
 * Builds a new {@link ChatResponseMetadata} from an existing one, replacing
 * its usage with the supplied value. The id and model are copied (an empty
 * string is used when either is {@code null}), and the rate limit is copied
 * only when it is non-null.
 *
 * NOTE(review): only id, usage, model and rate limit are carried over. The
 * "created" and "system-fingerprint" key/values that the ChatCompletion-based
 * {@code from} overload sets are not copied here -- confirm this is intended.
 * @param chatResponseMetadata the metadata to copy from; must not be null
 * @param usage the usage to set on the new metadata
 * @return the newly built metadata
 */
472+ private ChatResponseMetadata from (ChatResponseMetadata chatResponseMetadata , Usage usage ) {
473+ Assert .notNull (chatResponseMetadata , "OpenAI ChatResponseMetadata must not be null" );
474+ var builder = ChatResponseMetadata .builder ()
475+ .withId (chatResponseMetadata .getId () != null ? chatResponseMetadata .getId () : "" )
476+ .withUsage (usage )
477+ .withModel (chatResponseMetadata .getModel () != null ? chatResponseMetadata .getModel () : "" );
478+ if (chatResponseMetadata .getRateLimit () != null ) {
479+ builder .withRateLimit (chatResponseMetadata .getRateLimit ());
480+ }
481+ return builder .build ();
482+ }
483+
429484 /**
430485 * Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
431486 * @param chunk the ChatCompletionChunk to convert
@@ -533,7 +588,6 @@ else if (message.getMessageType() == MessageType.TOOL) {
533588 OpenAiChatOptions .builder ().withTools (this .getFunctionTools (enabledToolsToUse )).build (), request ,
534589 ChatCompletionRequest .class );
535590 }
536-
537591 // Remove `streamOptions` from the request if it is not a streaming request
538592 if (request .streamOptions () != null && !stream ) {
539593 logger .warn ("Removing streamOptions from the request as it is not a streaming request!" );
0 commit comments