5151import com .azure .ai .openai .models .ChatRequestToolMessage ;
5252import com .azure .ai .openai .models .ChatRequestUserMessage ;
5353import com .azure .ai .openai .models .CompletionsFinishReason ;
54+ import com .azure .ai .openai .models .CompletionsUsage ;
5455import com .azure .ai .openai .models .ContentFilterResultsForPrompt ;
5556import com .azure .ai .openai .models .FunctionCall ;
5657import com .azure .core .util .BinaryData ;
7071import org .springframework .ai .chat .metadata .PromptMetadata ;
7172import org .springframework .ai .chat .metadata .PromptMetadata .PromptFilterMetadata ;
7273import org .springframework .ai .chat .metadata .Usage ;
74+ import org .springframework .ai .chat .metadata .UsageUtils ;
7375import org .springframework .ai .chat .model .AbstractToolCallSupport ;
7476import org .springframework .ai .chat .model .ChatModel ;
7577import org .springframework .ai .chat .model .ChatResponse ;
105107 * @author timostark
106108 * @author Soby Chacko
107109 * @author Jihoon Kim
110+ * @author Ilayaperumal Gopinathan
108111 * @see ChatModel
109112 * @see com.azure.ai.openai.OpenAIClient
110113 * @since 1.0.0
@@ -176,10 +179,10 @@ public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAi
176179 this .observationRegistry = observationRegistry ;
177180 }
178181
179- public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ) {
182+ public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ,
183+ Usage usage ) {
180184 Assert .notNull (chatCompletions , "Azure OpenAI ChatCompletions must not be null" );
181185 String id = chatCompletions .getId ();
182- Usage usage = (chatCompletions .getUsage () != null ) ? AzureOpenAiUsage .from (chatCompletions ) : new EmptyUsage ();
183186 return ChatResponseMetadata .builder ()
184187 .withId (id )
185188 .withUsage (usage )
@@ -189,12 +192,40 @@ public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptM
189192 .build ();
190193 }
191194
195+ public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ) {
196+ Usage usage = (chatCompletions .getUsage () != null ) ? AzureOpenAiUsage .from (chatCompletions ) : new EmptyUsage ();
197+ return from (chatCompletions , promptFilterMetadata , usage );
198+ }
199+
200+ public static ChatResponseMetadata from (ChatCompletions chatCompletions , PromptMetadata promptFilterMetadata ,
201+ CompletionsUsage usage ) {
202+ return from (chatCompletions , promptFilterMetadata , AzureOpenAiUsage .from (usage ));
203+ }
204+
205+ public static ChatResponseMetadata from (ChatResponse chatResponse , Usage usage ) {
206+ Assert .notNull (chatResponse , "ChatResponse must not be null" );
207+ ChatResponseMetadata chatResponseMetadata = chatResponse .getMetadata ();
208+ ChatResponseMetadata .Builder builder = ChatResponseMetadata .builder ();
209+ builder .withId (chatResponseMetadata .getId ())
210+ .withUsage (usage )
211+ .withModel (chatResponseMetadata .getModel ())
212+ .withPromptMetadata (chatResponseMetadata .getPromptMetadata ());
213+ if (chatResponseMetadata .containsKey ("system-fingerprint" )) {
214+ builder .withKeyValue ("system-fingerprint" , chatResponseMetadata .get ("system-fingerprint" ));
215+ }
216+ return builder .build ();
217+ }
218+
192219 public AzureOpenAiChatOptions getDefaultOptions () {
193220 return AzureOpenAiChatOptions .fromOptions (this .defaultOptions );
194221 }
195222
196223 @ Override
197224 public ChatResponse call (Prompt prompt ) {
225+ return this .internalCall (prompt , null );
226+ }
227+
228+ public ChatResponse internalCall (Prompt prompt , ChatResponse previousChatResponse ) {
198229
199230 ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
200231 .prompt (prompt )
@@ -210,7 +241,7 @@ public ChatResponse call(Prompt prompt) {
210241 ChatCompletionsOptionsAccessHelper .setStream (options , false );
211242
212243 ChatCompletions chatCompletions = this .openAIClient .getChatCompletions (options .getModel (), options );
213- ChatResponse chatResponse = toChatResponse (chatCompletions );
244+ ChatResponse chatResponse = toChatResponse (chatCompletions , previousChatResponse );
214245 observationContext .setResponse (chatResponse );
215246 return chatResponse ;
216247 });
@@ -220,14 +251,18 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS
220251 var toolCallConversation = handleToolCalls (prompt , response );
221252 // Recursively call the call method with the tool call message
222253 // conversation that contains the call responses.
223- return this .call (new Prompt (toolCallConversation , prompt .getOptions ()));
254+ return this .internalCall (new Prompt (toolCallConversation , prompt .getOptions ()), response );
224255 }
225256
226257 return response ;
227258 }
228259
229260 @ Override
230261 public Flux <ChatResponse > stream (Prompt prompt ) {
262+ return this .internalStream (prompt , null );
263+ }
264+
265+ public Flux <ChatResponse > internalStream (Prompt prompt , ChatResponse previousChatResponse ) {
231266
232267 return Flux .deferContextual (contextView -> {
233268 ChatCompletionsOptions options = toAzureChatCompletionsOptions (prompt );
@@ -279,16 +314,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
279314 })
280315 .flatMap (mono -> mono );
281316
282- return accessibleChatCompletionsFlux .switchMap (chatCompletions -> {
283-
284- ChatResponse chatResponse = toChatResponse (chatCompletions );
317+ final Flux <ChatResponse > chatResponseFlux = accessibleChatCompletionsFlux .map (chatCompletion -> {
318+ if (previousChatResponse == null ) {
319+ return toChatResponse (chatCompletion );
320+ }
321+ // Accumulate the usage from the previous chat response
322+ CompletionsUsage usage = chatCompletion .getUsage ();
323+ Usage currentChatResponseUsage = usage != null ? AzureOpenAiUsage .from (usage ) : new EmptyUsage ();
324+ Usage accumulatedUsage = UsageUtils .getCumulativeUsage (currentChatResponseUsage , previousChatResponse );
325+ return toChatResponse (chatCompletion , accumulatedUsage );
326+ }).buffer (2 , 1 ).map (bufferList -> {
327+ ChatResponse chatResponse1 = bufferList .get (0 );
328+ if (options .getStreamOptions () != null && options .getStreamOptions ().isIncludeUsage ()) {
329+ if (bufferList .size () == 2 ) {
330+ ChatResponse chatResponse2 = bufferList .get (1 );
331+ if (chatResponse2 != null && chatResponse2 .getMetadata () != null
332+ && !UsageUtils .isEmpty (chatResponse2 .getMetadata ().getUsage ())) {
333+ return toChatResponse (chatResponse1 , chatResponse2 .getMetadata ().getUsage ());
334+ }
335+ }
336+ }
337+ return chatResponse1 ;
338+ });
285339
340+ return chatResponseFlux .flatMap (chatResponse -> {
286341 if (!isProxyToolCalls (prompt , this .defaultOptions ) && isToolCall (chatResponse ,
287342 Set .of (String .valueOf (CompletionsFinishReason .TOOL_CALLS ).toLowerCase ()))) {
288343 var toolCallConversation = handleToolCalls (prompt , chatResponse );
289344 // Recursively call the call method with the tool call message
290345 // conversation that contains the call responses.
291- return this .stream (new Prompt (toolCallConversation , prompt .getOptions ()));
346+ return this .internalStream (new Prompt (toolCallConversation , prompt .getOptions ()), chatResponse );
292347 }
293348
294349 Flux <ChatResponse > flux = Flux .just (chatResponse )
@@ -305,6 +360,44 @@ public Flux<ChatResponse> stream(Prompt prompt) {
305360
306361 private ChatResponse toChatResponse (ChatCompletions chatCompletions ) {
307362
363+ List <Generation > generations = nullSafeList (chatCompletions .getChoices ()).stream ().map (choice -> {
364+ // @formatter:off
365+ Map <String , Object > metadata = Map .of (
366+ "id" , chatCompletions .getId () != null ? chatCompletions .getId () : "" ,
367+ "choiceIndex" , choice .getIndex (),
368+ "finishReason" , choice .getFinishReason () != null ? String .valueOf (choice .getFinishReason ()) : "" );
369+ // @formatter:on
370+ return buildGeneration (choice , metadata );
371+ }).toList ();
372+
373+ PromptMetadata promptFilterMetadata = generatePromptMetadata (chatCompletions );
374+
375+ return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata ));
376+ }
377+
378+ private ChatResponse toChatResponse (ChatCompletions chatCompletions , Usage usage ) {
379+
380+ List <Generation > generations = nullSafeList (chatCompletions .getChoices ()).stream ().map (choice -> {
381+ // @formatter:off
382+ Map <String , Object > metadata = Map .of (
383+ "id" , chatCompletions .getId () != null ? chatCompletions .getId () : "" ,
384+ "choiceIndex" , choice .getIndex (),
385+ "finishReason" , choice .getFinishReason () != null ? String .valueOf (choice .getFinishReason ()) : "" );
386+ // @formatter:on
387+ return buildGeneration (choice , metadata );
388+ }).toList ();
389+
390+ PromptMetadata promptFilterMetadata = generatePromptMetadata (chatCompletions );
391+
392+ return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata , usage ));
393+ }
394+
395+ private ChatResponse toChatResponse (ChatResponse chatResponse , Usage usage ) {
396+ return new ChatResponse (chatResponse .getResults (), from (chatResponse , usage ));
397+ }
398+
399+ private ChatResponse toChatResponse (ChatCompletions chatCompletions , ChatResponse previousChatResponse ) {
400+
308401 List <Generation > generations = nullSafeList (chatCompletions .getChoices ()).stream ().map (choice -> {
309402 // @formatter:off
310403 Map <String , Object > metadata = Map .of (
@@ -316,8 +409,12 @@ private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
316409 }).toList ();
317410
318411 PromptMetadata promptFilterMetadata = generatePromptMetadata (chatCompletions );
319-
320- return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata ));
412+ Usage currentUsage = null ;
413+ if (chatCompletions .getUsage () != null ) {
414+ currentUsage = AzureOpenAiUsage .from (chatCompletions );
415+ }
416+ Usage cumulativeUsage = UsageUtils .getCumulativeUsage (currentUsage , previousChatResponse );
417+ return new ChatResponse (generations , from (chatCompletions , promptFilterMetadata , cumulativeUsage ));
321418 }
322419
323420 private Generation buildGeneration (ChatChoice choice , Map <String , Object > metadata ) {
0 commit comments