@@ -90,7 +90,8 @@ public static boolean isToolUseFinish(ConverseStreamOutput event) {
9090 return true ;
9191 }
9292
93- public static Flux <ChatResponse > toChatResponse (Flux <ConverseStreamOutput > responses ) {
93+ public static Flux <ChatResponse > toChatResponse (Flux <ConverseStreamOutput > responses ,
94+ ChatResponse perviousChatResponse ) {
9495
9596 AtomicBoolean isInsideTool = new AtomicBoolean (false );
9697
@@ -120,20 +121,30 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
120121
121122 List <AssistantMessage .ToolCall > toolCalls = new ArrayList <>();
122123
124+ Long promptTokens = 0L ;
125+ Long generationTokens = 0L ;
126+ Long totalTokens = 0L ;
127+
123128 for (ToolUseAggregationEvent .ToolUseEntry toolUseEntry : toolUseAggregationEvent .toolUseEntries ()) {
124129 var functionCallId = toolUseEntry .id ();
125130 var functionName = toolUseEntry .name ();
126131 var functionArguments = toolUseEntry .input ();
127132 toolCalls .add (
128133 new AssistantMessage .ToolCall (functionCallId , "function" , functionName , functionArguments ));
134+
135+ if (toolUseEntry .usage () != null ) {
136+ promptTokens += toolUseEntry .usage ().getPromptTokens ();
137+ generationTokens += toolUseEntry .usage ().getGenerationTokens ();
138+ totalTokens += toolUseEntry .usage ().getTotalTokens ();
139+ }
129140 }
130141
131142 AssistantMessage assistantMessage = new AssistantMessage ("" , Map .of (), toolCalls );
132143 Generation toolCallGeneration = new Generation (assistantMessage ,
133144 ChatGenerationMetadata .from ("tool_use" , null ));
134145
135146 var chatResponseMetaData = ChatResponseMetadata .builder ()
136- .withUsage (toolUseAggregationEvent . usage )
147+ .withUsage (new DefaultUsage ( promptTokens , generationTokens , totalTokens ) )
137148 .build ();
138149
139150 return new Aggregation (
@@ -181,22 +192,22 @@ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) {
181192 return new Aggregation ();
182193 }
183194 else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent ) {
184- // return new Aggregation();
195+
185196 var newMeta = MetadataAggregation .builder ()
186197 .copy (lastAggregation .metadataAggregation ())
187198 .withTokenUsage (metadataEvent .usage ())
188199 .withMetrics (metadataEvent .metrics ())
189200 .withTrace (metadataEvent .trace ())
190201 .build ();
191202
192- DefaultUsage usage = new DefaultUsage (metadataEvent .usage ().inputTokens ().longValue (),
193- metadataEvent .usage ().outputTokens ().longValue (),
194- metadataEvent .usage ().totalTokens ().longValue ());
195-
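196203 // TODO: modelResponseFields and metrics are fetched below but not yet added to the returned ChatResponse metadata (only usage is).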
197204 Document modelResponseFields = lastAggregation .metadataAggregation ().additionalModelResponseFields ();
198205 ConverseStreamMetrics metrics = metadataEvent .metrics ();
199206
207+ DefaultUsage usage = new DefaultUsage (metadataEvent .usage ().inputTokens ().longValue (),
208+ metadataEvent .usage ().outputTokens ().longValue (),
209+ metadataEvent .usage ().totalTokens ().longValue ());
210+
200211 var chatResponseMetaData = ChatResponseMetadata .builder ().withUsage (usage ).build ();
201212
202213 return new Aggregation (newMeta , new ChatResponse (List .of (), chatResponseMetaData ));
@@ -206,8 +217,42 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
206217 }
207218 })
208219 // .skip(1)
209- .map (aggregation -> aggregation .chatResponse ())
210- .filter (chatResponse -> chatResponse != ConverseApiUtils .EMPTY_CHAT_RESPONSE );
220+ .filter (aggregation -> aggregation .chatResponse () != ConverseApiUtils .EMPTY_CHAT_RESPONSE )
221+ .map (aggregation -> {
222+
223+ var chatResponse = aggregation .chatResponse ();
224+
225+ // Add the previous chat response's token usage to the current response's usage; id, model, rate limit and prompt metadata are taken from the current response only.
226+ if (perviousChatResponse != null && perviousChatResponse .getMetadata () != null
227+ && perviousChatResponse .getMetadata ().getUsage () != null ) {
228+
229+ var metadataBuilder = ChatResponseMetadata .builder ();
230+
231+ Long promptTokens = perviousChatResponse .getMetadata ().getUsage ().getPromptTokens ();
232+ Long generationTokens = perviousChatResponse .getMetadata ().getUsage ().getGenerationTokens ();
233+ Long totalTokens = perviousChatResponse .getMetadata ().getUsage ().getTotalTokens ();
234+
235+ if (chatResponse .getMetadata () != null ) {
236+ metadataBuilder .withId (chatResponse .getMetadata ().getId ());
237+ metadataBuilder .withModel (chatResponse .getMetadata ().getModel ());
238+ metadataBuilder .withRateLimit (chatResponse .getMetadata ().getRateLimit ());
239+ metadataBuilder .withPromptMetadata (chatResponse .getMetadata ().getPromptMetadata ());
240+
241+ if (chatResponse .getMetadata ().getUsage () != null ) {
242+ promptTokens = promptTokens + chatResponse .getMetadata ().getUsage ().getPromptTokens ();
243+ generationTokens = generationTokens
244+ + chatResponse .getMetadata ().getUsage ().getGenerationTokens ();
245+ totalTokens = totalTokens + chatResponse .getMetadata ().getUsage ().getTotalTokens ();
246+ }
247+ }
248+
249+ metadataBuilder .withUsage (new DefaultUsage (promptTokens , generationTokens , totalTokens ));
250+
251+ return new ChatResponse (chatResponse .getResults (), metadataBuilder .build ());
252+ }
253+
254+ return aggregation .chatResponse ();
255+ });
211256 }
212257
213258 public static ConverseStreamOutput mergeToolUseEvents (ConverseStreamOutput previousEvent ,
@@ -245,7 +290,7 @@ else if (event.sdkEventType() == EventType.METADATA) {
245290 DefaultUsage usage = new DefaultUsage (metadataEvent .usage ().inputTokens ().longValue (),
246291 metadataEvent .usage ().outputTokens ().longValue (), metadataEvent .usage ().totalTokens ().longValue ());
247292 toolUseEventAggregator .withUsage (usage );
248- // TODO
293+
249294 if (!toolUseEventAggregator .isEmpty ()) {
250295 toolUseEventAggregator .squashIntoContentBlock ();
251296 return toolUseEventAggregator ;
@@ -400,7 +445,7 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
400445 }
401446
402447 void squashIntoContentBlock () {
403- this .toolUseEntries .add (new ToolUseEntry (this .index , this .id , this .name , this .partialJson ));
448+ this .toolUseEntries .add (new ToolUseEntry (this .index , this .id , this .name , this .partialJson , this . usage ));
404449 this .index = null ;
405450 this .id = null ;
406451 this .name = null ;
@@ -424,7 +469,7 @@ public void accept(Visitor visitor) {
424469 throw new UnsupportedOperationException ();
425470 }
426471
427- public record ToolUseEntry (Integer index , String id , String name , String input ) {
472+ public record ToolUseEntry (Integer index , String id , String name , String input , DefaultUsage usage ) {
428473 }
429474
430475 }
0 commit comments