1616
1717package org .springframework .ai .chat .model ;
1818
19+ import java .util .ArrayList ;
1920import java .util .HashMap ;
2021import java .util .List ;
2122import java .util .Map ;
2425
2526import org .slf4j .Logger ;
2627import org .slf4j .LoggerFactory ;
28+ import org .springframework .util .CollectionUtils ;
2729import reactor .core .publisher .Flux ;
2830
2931import org .springframework .ai .chat .messages .AssistantMessage ;
3537import org .springframework .ai .chat .metadata .Usage ;
3638import org .springframework .util .StringUtils ;
3739
40+ import static org .springframework .ai .chat .messages .AssistantMessage .*;
41+
3842/**
3943 * Helper that for streaming chat responses, aggregate the chat response messages into a
4044 * single AssistantMessage. Job is performed in parallel to the chat response processing.
4145 *
4246 * @author Christian Tzolov
4347 * @author Alexandros Pappas
4448 * @author Thomas Vitale
49+ * @author Heonwoo Kim
4550 * @since 1.0.0
4651 */
4752public class MessageAggregator {
@@ -54,6 +59,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
5459 // Assistant Message
5560 AtomicReference <StringBuilder > messageTextContentRef = new AtomicReference <>(new StringBuilder ());
5661 AtomicReference <Map <String , Object >> messageMetadataMapRef = new AtomicReference <>();
62+ AtomicReference <List <ToolCall >> toolCallsRef = new AtomicReference <>(new ArrayList <>());
5763
5864 // ChatGeneration Metadata
5965 AtomicReference <ChatGenerationMetadata > generationMetadataRef = new AtomicReference <>(
@@ -73,6 +79,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
7379 return fluxChatResponse .doOnSubscribe (subscription -> {
7480 messageTextContentRef .set (new StringBuilder ());
7581 messageMetadataMapRef .set (new HashMap <>());
82+ toolCallsRef .set (new ArrayList <>());
7683 metadataIdRef .set ("" );
7784 metadataModelRef .set ("" );
7885 metadataUsagePromptTokensRef .set (0 );
@@ -94,6 +101,11 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
94101 if (chatResponse .getResult ().getOutput ().getMetadata () != null ) {
95102 messageMetadataMapRef .get ().putAll (chatResponse .getResult ().getOutput ().getMetadata ());
96103 }
104+ AssistantMessage outputMessage = chatResponse .getResult ().getOutput ();
105+ if (!CollectionUtils .isEmpty (outputMessage .getToolCalls ())) {
106+ toolCallsRef .get ().addAll (outputMessage .getToolCalls ());
107+ }
108+
97109 }
98110 if (chatResponse .getMetadata () != null ) {
99111 if (chatResponse .getMetadata ().getUsage () != null ) {
@@ -119,6 +131,13 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
119131 if (StringUtils .hasText (chatResponse .getMetadata ().getModel ())) {
120132 metadataModelRef .set (chatResponse .getMetadata ().getModel ());
121133 }
134+ Object toolCallsFromMetadata = chatResponse .getMetadata ().get ("toolCalls" );
135+ if (toolCallsFromMetadata instanceof List ) {
136+ @ SuppressWarnings ("unchecked" )
137+ List <ToolCall > toolCallsList = (List <ToolCall >) toolCallsFromMetadata ;
138+ toolCallsRef .get ().addAll (toolCallsList );
139+ }
140+
122141 }
123142 }).doOnComplete (() -> {
124143
@@ -133,12 +152,25 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133152 .promptMetadata (metadataPromptMetadataRef .get ())
134153 .build ();
135154
136- onAggregationComplete .accept (new ChatResponse (List .of (new Generation (
137- new AssistantMessage (messageTextContentRef .get ().toString (), messageMetadataMapRef .get ()),
155+ AssistantMessage finalAssistantMessage ;
156+ List <ToolCall > collectedToolCalls = toolCallsRef .get ();
157+
158+ if (!CollectionUtils .isEmpty (collectedToolCalls )) {
159+
160+ finalAssistantMessage = new AssistantMessage (messageTextContentRef .get ().toString (),
161+ messageMetadataMapRef .get (), collectedToolCalls );
162+ }
163+ else {
164+ finalAssistantMessage = new AssistantMessage (messageTextContentRef .get ().toString (),
165+ messageMetadataMapRef .get ());
166+ }
167+ onAggregationComplete .accept (new ChatResponse (List .of (new Generation (finalAssistantMessage ,
168+
138169 generationMetadataRef .get ())), chatResponseMetadata ));
139170
140171 messageTextContentRef .set (new StringBuilder ());
141172 messageMetadataMapRef .set (new HashMap <>());
173+ toolCallsRef .set (new ArrayList <>());
142174 metadataIdRef .set ("" );
143175 metadataModelRef .set ("" );
144176 metadataUsagePromptTokensRef .set (0 );
0 commit comments