Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,13 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.advisor.api.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
Expand Down Expand Up @@ -201,7 +195,7 @@ public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamA
// @formatter:on

return advisedResponses.map(ar -> {
if (onFinishReason().test(ar)) {
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
Expand Down Expand Up @@ -260,16 +254,6 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

}

private Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> advisedResponse.response()
.getResults()
.stream()
.filter(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()))
.findFirst()
.isPresent();
}

public static final class Builder {

private final VectorStore vectorStore;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.springframework.ai.chat.client.advisor.api;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.util.StringUtils;

import java.util.function.Predicate;

/**
* A stream utility class to provide support methods handling {@link AdvisedResponse}.
*/
public final class AdvisedResponseStreamUtils {

/**
* Returns a predicate that checks whether the provided {@link AdvisedResponse}
* contains a {@link ChatResponse} with at least one result having a non-empty finish
* reason in its metadata.
* @return a {@link Predicate} that evaluates whether the finish reason exists within
* the response metadata.
*/
public static Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> {
ChatResponse chatResponse = advisedResponse.response();
return chatResponse != null && chatResponse.getResults() != null
&& chatResponse.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@

package org.springframework.ai.chat.client.advisor.api;

import java.util.function.Predicate;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Base advisor that implements common aspects of the {@link CallAroundAdvisor} and
Expand Down Expand Up @@ -65,24 +61,13 @@ default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, Stream
.flatMapMany(chain::nextAroundStream);

return advisedResponses.map(ar -> {
if (onFinishReason().test(ar)) {
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
ar = after(ar);
}
return ar;
}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
}

private Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> {
ChatResponse chatResponse = advisedResponse.response();
return chatResponse != null && chatResponse.getResults() != null
&& chatResponse.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
};
}

@Override
default String getName() {
return this.getClass().getSimpleName();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.springframework.ai.chat.client.advisor.api;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;

/**
* Unit tests for {@link AdvisedResponseStreamUtils}.
*
* @author ghdcksgml1
*/
class AdvisedResponseStreamUtilsTest {

@Nested
class OnFinishReason {

@Test
void whenChatResponseIsNullThenReturnFalse() {
AdvisedResponse response = mock(AdvisedResponse.class);
given(response.response()).willReturn(null);

boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);

assertFalse(result);
}

@Test
void whenChatResponseResultsIsNullThenReturnFalse() {
AdvisedResponse response = mock(AdvisedResponse.class);
ChatResponse chatResponse = mock(ChatResponse.class);

given(chatResponse.getResults()).willReturn(null);
given(response.response()).willReturn(chatResponse);

boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);

assertFalse(result);
}

@Test
void whenChatIsRunningThenReturnFalse() {
AdvisedResponse response = mock(AdvisedResponse.class);
ChatResponse chatResponse = mock(ChatResponse.class);

Generation generation = new Generation(new AssistantMessage("running.."), ChatGenerationMetadata.NULL);

given(chatResponse.getResults()).willReturn(List.of(generation));
given(response.response()).willReturn(chatResponse);

boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);

assertFalse(result);
}

@Test
void whenChatIsStopThenReturnTrue() {
AdvisedResponse response = mock(AdvisedResponse.class);
ChatResponse chatResponse = mock(ChatResponse.class);

Generation generation = new Generation(new AssistantMessage("finish."),
ChatGenerationMetadata.builder().finishReason("STOP").build());

given(chatResponse.getResults()).willReturn(List.of(generation));
given(response.response()).willReturn(chatResponse);

boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);

assertTrue(result);
}

}

}