4747import org .springframework .ai .chat .messages .UserMessage ;
4848import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
4949import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
50+ import org .springframework .ai .chat .metadata .EmptyUsage ;
51+ import org .springframework .ai .chat .metadata .Usage ;
52+ import org .springframework .ai .chat .metadata .UsageUtils ;
5053import org .springframework .ai .chat .model .AbstractToolCallSupport ;
5154import org .springframework .ai .chat .model .ChatModel ;
5255import org .springframework .ai .chat .model .ChatResponse ;
@@ -211,6 +214,10 @@ public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaul
211214
212215 @ Override
213216 public ChatResponse call (Prompt prompt ) {
217+ return this .internalCall (prompt , null );
218+ }
219+
220+ public ChatResponse internalCall (Prompt prompt , ChatResponse previousChatResponse ) {
214221 ChatCompletionRequest request = createRequest (prompt , false );
215222
216223 ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
@@ -227,8 +234,14 @@ public ChatResponse call(Prompt prompt) {
227234 ResponseEntity <ChatCompletionResponse > completionEntity = this .retryTemplate
228235 .execute (ctx -> this .anthropicApi .chatCompletionEntity (request ));
229236
230- ChatResponse chatResponse = toChatResponse (completionEntity .getBody ());
237+ AnthropicApi .ChatCompletionResponse completionResponse = completionEntity .getBody ();
238+ AnthropicApi .Usage usage = completionResponse .usage ();
231239
240+ Usage currentChatResponseUsage = usage != null ? AnthropicUsage .from (completionResponse .usage ())
241+ : new EmptyUsage ();
242+ Usage accumulatedUsage = UsageUtils .getCumulativeUsage (currentChatResponseUsage , previousChatResponse );
243+
244+ ChatResponse chatResponse = toChatResponse (completionEntity .getBody (), accumulatedUsage );
232245 observationContext .setResponse (chatResponse );
233246
234247 return chatResponse ;
@@ -237,14 +250,18 @@ public ChatResponse call(Prompt prompt) {
237250 if (!isProxyToolCalls (prompt , this .defaultOptions ) && response != null
238251 && this .isToolCall (response , Set .of ("tool_use" ))) {
239252 var toolCallConversation = handleToolCalls (prompt , response );
240- return this .call (new Prompt (toolCallConversation , prompt .getOptions ()));
253+ return this .internalCall (new Prompt (toolCallConversation , prompt .getOptions ()), response );
241254 }
242255
243256 return response ;
244257 }
245258
246259 @ Override
247260 public Flux <ChatResponse > stream (Prompt prompt ) {
261+ return this .internalStream (prompt , null );
262+ }
263+
264+ public Flux <ChatResponse > internalStream (Prompt prompt , ChatResponse previousChatResponse ) {
248265 return Flux .deferContextual (contextView -> {
249266 ChatCompletionRequest request = createRequest (prompt , true );
250267
@@ -264,11 +281,14 @@ public Flux<ChatResponse> stream(Prompt prompt) {
264281
265282 // @formatter:off
266283 Flux <ChatResponse > chatResponseFlux = response .switchMap (chatCompletionResponse -> {
267- ChatResponse chatResponse = toChatResponse (chatCompletionResponse );
284+ AnthropicApi .Usage usage = chatCompletionResponse .usage ();
285+ Usage currentChatResponseUsage = usage != null ? AnthropicUsage .from (chatCompletionResponse .usage ()) : new EmptyUsage ();
286+ Usage accumulatedUsage = UsageUtils .getCumulativeUsage (currentChatResponseUsage , previousChatResponse );
287+ ChatResponse chatResponse = toChatResponse (chatCompletionResponse , accumulatedUsage );
268288
269289 if (!isProxyToolCalls (prompt , this .defaultOptions ) && this .isToolCall (chatResponse , Set .of ("tool_use" ))) {
270290 var toolCallConversation = handleToolCalls (prompt , chatResponse );
271- return this .stream (new Prompt (toolCallConversation , prompt .getOptions ()));
291+ return this .internalStream (new Prompt (toolCallConversation , prompt .getOptions ()), chatResponse );
272292 }
273293
274294 return Mono .just (chatResponse );
@@ -282,7 +302,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
282302 });
283303 }
284304
285- private ChatResponse toChatResponse (ChatCompletionResponse chatCompletion ) {
305+ private ChatResponse toChatResponse (ChatCompletionResponse chatCompletion , Usage usage ) {
286306
287307 if (chatCompletion == null ) {
288308 logger .warn ("Null chat completion returned" );
@@ -328,12 +348,15 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
328348 allGenerations .add (toolCallGeneration );
329349 }
330350
331- return new ChatResponse (allGenerations , this .from (chatCompletion ));
351+ return new ChatResponse (allGenerations , this .from (chatCompletion , usage ));
332352 }
333353
334354 private ChatResponseMetadata from (AnthropicApi .ChatCompletionResponse result ) {
355+ return from (result , AnthropicUsage .from (result .usage ()));
356+ }
357+
358+ private ChatResponseMetadata from (AnthropicApi .ChatCompletionResponse result , Usage usage ) {
335359 Assert .notNull (result , "Anthropic ChatCompletionResult must not be null" );
336- AnthropicUsage usage = AnthropicUsage .from (result .usage ());
337360 return ChatResponseMetadata .builder ()
338361 .withId (result .id ())
339362 .withModel (result .model ())
0 commit comments