diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 839d99e23d8..222d53d426f 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -16,6 +16,7 @@ package org.springframework.ai.chat.model; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -24,6 +25,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.CollectionUtils; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -35,6 +37,8 @@ import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.StringUtils; +import static org.springframework.ai.chat.messages.AssistantMessage.*; + /** * Helper that for streaming chat responses, aggregate the chat response messages into a * single AssistantMessage. Job is performed in parallel to the chat response processing. @@ -42,6 +46,7 @@ * @author Christian Tzolov * @author Alexandros Pappas * @author Thomas Vitale + * @author Heonwoo Kim * @since 1.0.0 */ public class MessageAggregator { @@ -54,6 +59,7 @@ public Flux aggregate(Flux fluxChatResponse, // Assistant Message AtomicReference messageTextContentRef = new AtomicReference<>(new StringBuilder()); AtomicReference> messageMetadataMapRef = new AtomicReference<>(); + AtomicReference> toolCallsRef = new AtomicReference<>(new ArrayList<>()); // ChatGeneration Metadata AtomicReference generationMetadataRef = new AtomicReference<>( @@ -73,6 +79,7 @@ public Flux aggregate(Flux fluxChatResponse, return fluxChatResponse.doOnSubscribe(subscription -> { messageTextContentRef.set(new StringBuilder()); messageMetadataMapRef.set(new HashMap<>()); + toolCallsRef.set(new ArrayList<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0); @@ -94,6 +101,11 @@ public Flux aggregate(Flux fluxChatResponse, if (chatResponse.getResult().getOutput().getMetadata() != null) { messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata()); } + AssistantMessage outputMessage = chatResponse.getResult().getOutput(); + if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) { + toolCallsRef.get().addAll(outputMessage.getToolCalls()); + } + } if (chatResponse.getMetadata() != null) { if (chatResponse.getMetadata().getUsage() != null) { @@ -119,6 +131,13 @@ public Flux aggregate(Flux fluxChatResponse, if (StringUtils.hasText(chatResponse.getMetadata().getModel())) { metadataModelRef.set(chatResponse.getMetadata().getModel()); } + Object toolCallsFromMetadata = chatResponse.getMetadata().get("toolCalls"); + if (toolCallsFromMetadata instanceof List) { + @SuppressWarnings("unchecked") + List toolCallsList = (List) toolCallsFromMetadata; + toolCallsRef.get().addAll(toolCallsList); + } + } }).doOnComplete(() -> { @@ -133,12 +152,25 @@ public Flux aggregate(Flux fluxChatResponse, .promptMetadata(metadataPromptMetadataRef.get()) .build(); - onAggregationComplete.accept(new ChatResponse(List.of(new Generation( - new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()), + AssistantMessage finalAssistantMessage; + List collectedToolCalls = toolCallsRef.get(); + + if (!CollectionUtils.isEmpty(collectedToolCalls)) { + + finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), + messageMetadataMapRef.get(), collectedToolCalls); + } + else { + finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(), + messageMetadataMapRef.get()); + } + onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage, + generationMetadataRef.get())), chatResponseMetadata)); messageTextContentRef.set(new StringBuilder()); messageMetadataMapRef.set(new HashMap<>()); + toolCallsRef.set(new ArrayList<>()); metadataIdRef.set(""); metadataModelRef.set(""); metadataUsagePromptTokensRef.set(0); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java index a9d673e23be..08cdc993fba 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java @@ -19,27 +19,32 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.ai.chat.messages.AssistantMessage.*; /** * Unit tests for {@link ChatResponse}. * * @author Thomas Vitale + * @author Heonwoo Kim */ class ChatResponseTests { @Test void whenToolCallsArePresentThenReturnTrue() { ChatResponse chatResponse = ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage("", Map.of(), - List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .generations(List.of(new Generation( + new AssistantMessage("", Map.of(), List.of(new ToolCall("toolA", "function", "toolA", "{}")))))) .build(); assertThat(chatResponse.hasToolCalls()).isTrue(); } @@ -80,4 +85,45 @@ void whenFinishReasonIsNotPresent() { assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isFalse(); } + @Test + void messageAggregatorShouldCorrectlyAggregateToolCallsFromStream() { + + MessageAggregator aggregator = new MessageAggregator(); + + ChatResponse chunk1 = new ChatResponse( + List.of(new Generation(new AssistantMessage("Thinking about the weather... ")))); + + ToolCall weatherToolCall = new ToolCall("tool-id-123", "function", "getCurrentWeather", + "{\"location\": \"Seoul\"}"); + + Map metadataWithToolCall = Map.of("toolCalls", List.of(weatherToolCall)); + ChatResponseMetadata responseMetadataForChunk2 = ChatResponseMetadata.builder() + .metadata(metadataWithToolCall) + .build(); + + ChatResponse chunk2 = new ChatResponse(List.of(new Generation(new AssistantMessage(""))), + responseMetadataForChunk2); + + Flux streamingResponse = Flux.just(chunk1, chunk2); + + AtomicReference aggregatedResponseRef = new AtomicReference<>(); + + aggregator.aggregate(streamingResponse, aggregatedResponseRef::set).blockLast(); + + ChatResponse finalResponse = aggregatedResponseRef.get(); + assertThat(finalResponse).isNotNull(); + + AssistantMessage finalAssistantMessage = finalResponse.getResult().getOutput(); + + assertThat(finalAssistantMessage).isNotNull(); + assertThat(finalAssistantMessage.getText()).isEqualTo("Thinking about the weather... "); + assertThat(finalAssistantMessage.hasToolCalls()).isTrue(); + assertThat(finalAssistantMessage.getToolCalls()).hasSize(1); + + ToolCall resultToolCall = finalAssistantMessage.getToolCalls().get(0); + assertThat(resultToolCall.id()).isEqualTo("tool-id-123"); + assertThat(resultToolCall.name()).isEqualTo("getCurrentWeather"); + assertThat(resultToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}"); + } + }