|
26 | 26 | import java.util.List; |
27 | 27 | import java.util.Map; |
28 | 28 | import java.util.Optional; |
| 29 | +import java.util.Objects; |
29 | 30 | import java.util.function.Consumer; |
30 | 31 |
|
31 | 32 | import io.micrometer.observation.Observation; |
|
47 | 48 | import org.springframework.ai.chat.model.ChatModel; |
48 | 49 | import org.springframework.ai.chat.model.ChatResponse; |
49 | 50 | import org.springframework.ai.chat.model.Generation; |
| 51 | +import org.springframework.ai.model.tool.ToolExecutionResult; |
50 | 52 | import org.springframework.ai.chat.prompt.ChatOptions; |
51 | 53 | import org.springframework.ai.chat.prompt.Prompt; |
52 | 54 | import org.springframework.ai.content.Media; |
@@ -521,8 +523,33 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c |
521 | 523 |
|
522 | 524 | @Nullable |
523 | 525 | private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) { |
524 | | - return Optional.ofNullable(chatResponse) |
525 | | - .map(ChatResponse::getResult) |
| 526 | + if (chatResponse == null) { |
| 527 | + return null; |
| 528 | + } |
| 529 | + var results = chatResponse.getResults(); |
| 530 | + if (results == null || results.isEmpty()) { |
| 531 | + return null; |
| 532 | + } |
| 533 | + if (results.size() == 1) { |
| 534 | + return Optional.ofNullable(results.get(0)) |
| 535 | + .map(Generation::getOutput) |
| 536 | + .map(AbstractMessage::getText) |
| 537 | + .orElse(null); |
| 538 | + } |
| 539 | + boolean allReturnDirect = results.stream().allMatch(g -> { |
| 540 | + var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null; |
| 541 | + return finish != null && finish.equalsIgnoreCase(ToolExecutionResult.FINISH_REASON); |
| 542 | + }); |
| 543 | + if (allReturnDirect) { |
| 544 | + return results.stream() |
| 545 | + .map(Generation::getOutput) |
| 546 | + .map(AbstractMessage::getText) |
| 547 | + .filter(Objects::nonNull) |
| 548 | + .filter(StringUtils::hasText) |
| 549 | + .reduce((a, b) -> a + "\n" + b) |
| 550 | + .orElse(null); |
| 551 | + } |
| 552 | + return Optional.ofNullable(results.get(0)) |
526 | 553 | .map(Generation::getOutput) |
527 | 554 | .map(AbstractMessage::getText) |
528 | 555 | .orElse(null); |
@@ -594,10 +621,35 @@ public Flux<String> content() { |
594 | 621 | // @formatter:off |
595 | 622 | return doGetObservableFluxChatResponse(this.request) |
596 | 623 | .mapNotNull(ChatClientResponse::chatResponse) |
597 | | - .map(r -> Optional.ofNullable(r.getResult()) |
| 624 | + .map(r -> { |
| 625 | + var results = r.getResults(); |
| 626 | + if (results == null || results.isEmpty()) { |
| 627 | + return ""; |
| 628 | + } |
| 629 | + if (results.size() == 1) { |
| 630 | + return Optional.ofNullable(results.get(0)) |
| 631 | + .map(Generation::getOutput) |
| 632 | + .map(AbstractMessage::getText) |
| 633 | + .orElse(""); |
| 634 | + } |
| 635 | + boolean allReturnDirect = results.stream().allMatch(g -> { |
| 636 | + var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null; |
| 637 | + return finish != null && finish.equalsIgnoreCase(org.springframework.ai.model.tool.ToolExecutionResult.FINISH_REASON); |
| 638 | + }); |
| 639 | + if (allReturnDirect) { |
| 640 | + return results.stream() |
| 641 | + .map(Generation::getOutput) |
| 642 | + .map(AbstractMessage::getText) |
| 643 | + .filter(java.util.Objects::nonNull) |
| 644 | + .filter(StringUtils::hasText) |
| 645 | + .reduce((a, b) -> a + "\n" + b) |
| 646 | + .orElse(""); |
| 647 | + } |
| 648 | + return Optional.ofNullable(results.get(0)) |
598 | 649 | .map(Generation::getOutput) |
599 | 650 | .map(AbstractMessage::getText) |
600 | | - .orElse("")) |
| 651 | + .orElse(""); |
| 652 | + }) |
601 | 653 | .filter(StringUtils::hasLength); |
602 | 654 | // @formatter:on |
603 | 655 | } |
|
0 commit comments