diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index d962b5d37f0..dac8ef4529f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -19,19 +19,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Predicate; import java.util.stream.Collectors; +import org.springframework.ai.chat.client.advisor.api.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; @@ -201,7 +195,7 @@ public Flux aroundStream(AdvisedRequest advisedRequest, StreamA // @formatter:on return advisedResponses.map(ar -> { - if (onFinishReason().test(ar)) { + if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) { ar = after(ar); } return ar; @@ -260,16 +254,6 @@ protected Filter.Expression doGetFilterExpression(Map context) { } - private Predicate onFinishReason() { - return advisedResponse -> advisedResponse.response() - .getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - } - public static final class Builder { private final VectorStore vectorStore; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java new file mode 100644 index 00000000000..91bca7581a7 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java @@ -0,0 +1,31 @@ +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.StringUtils; + +import java.util.function.Predicate; + +/** + * A stream utility class to provide support methods handling {@link AdvisedResponse}. + */ +public final class AdvisedResponseStreamUtils { + + /** + * Returns a predicate that checks whether the provided {@link AdvisedResponse} + * contains a {@link ChatResponse} with at least one result having a non-empty finish + * reason in its metadata. + * @return a {@link Predicate} that evaluates whether the finish reason exists within + * the response metadata. + */ + public static Predicate onFinishReason() { + return advisedResponse -> { + ChatResponse chatResponse = advisedResponse.response(); + return chatResponse != null && chatResponse.getResults() != null + && chatResponse.getResults() + .stream() + .anyMatch(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())); + }; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java index 881b6be1017..47778787054 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java @@ -16,16 +16,12 @@ package org.springframework.ai.chat.client.advisor.api; -import java.util.function.Predicate; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; /** * Base advisor that implements common aspects of the {@link CallAroundAdvisor} and @@ -65,24 +61,13 @@ default Flux aroundStream(AdvisedRequest advisedRequest, Stream .flatMapMany(chain::nextAroundStream); return advisedResponses.map(ar -> { - if (onFinishReason().test(ar)) { + if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) { ar = after(ar); } return ar; }).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error))); } - private Predicate onFinishReason() { - return advisedResponse -> { - ChatResponse chatResponse = advisedResponse.response(); - return chatResponse != null && chatResponse.getResults() != null - && chatResponse.getResults() - .stream() - .anyMatch(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())); - }; - } - @Override default String getName() { return this.getClass().getSimpleName(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java new file mode 100644 index 00000000000..a02f1c8e7fc --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java @@ -0,0 +1,82 @@ +package org.springframework.ai.chat.client.advisor.api; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link AdvisedResponseStreamUtils}. + * + * @author ghdcksgml1 + */ +class AdvisedResponseStreamUtilsTest { + + @Nested + class OnFinishReason { + + @Test + void whenChatResponseIsNullThenReturnFalse() { + AdvisedResponse response = mock(AdvisedResponse.class); + given(response.response()).willReturn(null); + + boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + + assertFalse(result); + } + + @Test + void whenChatResponseResultsIsNullThenReturnFalse() { + AdvisedResponse response = mock(AdvisedResponse.class); + ChatResponse chatResponse = mock(ChatResponse.class); + + given(chatResponse.getResults()).willReturn(null); + given(response.response()).willReturn(chatResponse); + + boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + + assertFalse(result); + } + + @Test + void whenChatIsRunningThenReturnFalse() { + AdvisedResponse response = mock(AdvisedResponse.class); + ChatResponse chatResponse = mock(ChatResponse.class); + + Generation generation = new Generation(new AssistantMessage("running.."), ChatGenerationMetadata.NULL); + + given(chatResponse.getResults()).willReturn(List.of(generation)); + given(response.response()).willReturn(chatResponse); + + boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + + assertFalse(result); + } + + @Test + void whenChatIsStopThenReturnTrue() { + AdvisedResponse response = mock(AdvisedResponse.class); + ChatResponse chatResponse = mock(ChatResponse.class); + + Generation generation = new Generation(new AssistantMessage("finish."), + ChatGenerationMetadata.builder().finishReason("STOP").build()); + + given(chatResponse.getResults()).willReturn(List.of(generation)); + given(response.response()).willReturn(chatResponse); + + boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + + assertTrue(result); + } + + } + +}