Skip to content

Commit 87cdf62

Browse files
committed
fix(chat): handle multiple returnDirect responses in DefaultChatClient (#4655)
Signed-off-by: Kuntal Maity <[email protected]>
1 parent 0fdb911 commit 87cdf62

File tree

2 files changed

+119
-4
lines changed

2 files changed

+119
-4
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.List;
2727
import java.util.Map;
2828
import java.util.Optional;
29+
import java.util.Objects;
2930
import java.util.function.Consumer;
3031

3132
import io.micrometer.observation.Observation;
@@ -47,6 +48,7 @@
4748
import org.springframework.ai.chat.model.ChatModel;
4849
import org.springframework.ai.chat.model.ChatResponse;
4950
import org.springframework.ai.chat.model.Generation;
51+
import org.springframework.ai.model.tool.ToolExecutionResult;
5052
import org.springframework.ai.chat.prompt.ChatOptions;
5153
import org.springframework.ai.chat.prompt.Prompt;
5254
import org.springframework.ai.content.Media;
@@ -521,8 +523,33 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c
521523

522524
@Nullable
523525
private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) {
524-
return Optional.ofNullable(chatResponse)
525-
.map(ChatResponse::getResult)
526+
if (chatResponse == null) {
527+
return null;
528+
}
529+
var results = chatResponse.getResults();
530+
if (results == null || results.isEmpty()) {
531+
return null;
532+
}
533+
if (results.size() == 1) {
534+
return Optional.ofNullable(results.get(0))
535+
.map(Generation::getOutput)
536+
.map(AbstractMessage::getText)
537+
.orElse(null);
538+
}
539+
boolean allReturnDirect = results.stream().allMatch(g -> {
540+
var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null;
541+
return finish != null && finish.equalsIgnoreCase(ToolExecutionResult.FINISH_REASON);
542+
});
543+
if (allReturnDirect) {
544+
return results.stream()
545+
.map(Generation::getOutput)
546+
.map(AbstractMessage::getText)
547+
.filter(Objects::nonNull)
548+
.filter(StringUtils::hasText)
549+
.reduce((a, b) -> a + "\n" + b)
550+
.orElse(null);
551+
}
552+
return Optional.ofNullable(results.get(0))
526553
.map(Generation::getOutput)
527554
.map(AbstractMessage::getText)
528555
.orElse(null);
@@ -594,10 +621,35 @@ public Flux<String> content() {
594621
// @formatter:off
595622
return doGetObservableFluxChatResponse(this.request)
596623
.mapNotNull(ChatClientResponse::chatResponse)
597-
.map(r -> Optional.ofNullable(r.getResult())
624+
.map(r -> {
625+
var results = r.getResults();
626+
if (results == null || results.isEmpty()) {
627+
return "";
628+
}
629+
if (results.size() == 1) {
630+
return Optional.ofNullable(results.get(0))
631+
.map(Generation::getOutput)
632+
.map(AbstractMessage::getText)
633+
.orElse("");
634+
}
635+
boolean allReturnDirect = results.stream().allMatch(g -> {
636+
var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null;
637+
return finish != null && finish.equalsIgnoreCase(org.springframework.ai.model.tool.ToolExecutionResult.FINISH_REASON);
638+
});
639+
if (allReturnDirect) {
640+
return results.stream()
641+
.map(Generation::getOutput)
642+
.map(AbstractMessage::getText)
643+
.filter(java.util.Objects::nonNull)
644+
.filter(StringUtils::hasText)
645+
.reduce((a, b) -> a + "\n" + b)
646+
.orElse("");
647+
}
648+
return Optional.ofNullable(results.get(0))
598649
.map(Generation::getOutput)
599650
.map(AbstractMessage::getText)
600-
.orElse(""))
651+
.orElse("");
652+
})
601653
.filter(StringUtils::hasLength);
602654
// @formatter:on
603655
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package org.springframework.ai.chat.client;
2+
3+
import java.util.List;
4+
5+
import org.junit.jupiter.api.Test;
6+
7+
import org.springframework.ai.chat.messages.AssistantMessage;
8+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
9+
import org.springframework.ai.chat.model.ChatModel;
10+
import org.springframework.ai.chat.model.ChatResponse;
11+
import org.springframework.ai.chat.model.Generation;
12+
import org.springframework.ai.chat.prompt.Prompt;
13+
import org.springframework.ai.model.tool.ToolExecutionResult;
14+
15+
import static org.assertj.core.api.Assertions.assertThat;
16+
17+
/*
18+
* @author: Kuntal Maity
19+
*/
20+
class DefaultChatClientReturnDirectAggregationTests {
21+
22+
private static Generation generation(String text, String finishReason) {
23+
var metadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
24+
return new Generation(new AssistantMessage(text), metadata);
25+
}
26+
27+
@Test
28+
void aggregatesMultipleReturnDirectGenerationsInContent() {
29+
var chatResponse = new ChatResponse(List.of(generation("DATE=2025-10-18", ToolExecutionResult.FINISH_REASON),
30+
generation("TIME=12:34:56.789", ToolExecutionResult.FINISH_REASON)));
31+
32+
ChatModel stub = new ChatModel() {
33+
@Override
34+
public ChatResponse call(Prompt prompt) {
35+
return chatResponse;
36+
}
37+
};
38+
39+
var client = ChatClient.builder(stub).build();
40+
String content = client.prompt("now").call().content();
41+
42+
assertThat(content).isEqualTo("DATE=2025-10-18\nTIME=12:34:56.789");
43+
}
44+
45+
@Test
46+
void returnsFirstWhenNotAllReturnDirect() {
47+
var chatResponse = new ChatResponse(
48+
List.of(generation("FIRST", ToolExecutionResult.FINISH_REASON), generation("SECOND", "stop")));
49+
50+
ChatModel stub = new ChatModel() {
51+
@Override
52+
public ChatResponse call(Prompt prompt) {
53+
return chatResponse;
54+
}
55+
};
56+
57+
var client = ChatClient.builder(stub).build();
58+
String content = client.prompt("now").call().content();
59+
60+
assertThat(content).isEqualTo("FIRST");
61+
}
62+
63+
}

0 commit comments

Comments
 (0)