diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 5b0cd73e441..26635b2fde9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -28,11 +28,11 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.OpenAiChatModel; @@ -65,7 +65,7 @@ public class OpenAiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { + private static class LoggingAdvisor implements CallAroundAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); @@ -74,7 +74,23 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse 
aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; + } + + private AdvisedRequest before(AdvisedRequest request) { logger.info("System text: \n" + request.systemText()); logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); @@ -86,10 +102,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map return request; } - @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.info("Response: " + response); - return response; + private void observeAfter(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 3d46d1f6383..5dfaa0be5ba 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.stream.Collectors; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; @@ -64,6 +65,38 @@ class OpenAiChatClientIT extends AbstractIT { record ActorsFilms(String actor, List movies) { } + @Test + @Disabled("Although the Re2 advisor improves the response correctness it is not always guarantied to work.") + void re2() { + // .user(" Could Scooby Doo fit in a Kangaroo Pouch? Choices: (A) Yes (B) No") + // .user("Roger has 5 tennis balls. 
He buys 2 more cans of tennis " + + // "balls. Each can has 3 tennis balls. How many tennis balls " + + // "does he have now?") + + String REASON_QUESTION = """ + What do these words have in common? + Freight Stone Often Canine. + """; + + // @formatter:off + ChatClient chatClient = ChatClient.builder(chatModel) + .defaultOptions(OpenAiChatOptions.builder() + .withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build()) + .defaultUser(REASON_QUESTION) + .build(); + + String response = chatClient.prompt() + .advisors(new ReReadingAdvisor()) + .call() + .content(); + // @formatter:on + + logger.info("" + response); + assertThat(response.toLowerCase().replace("(", " ").replace(")", " ").replace("\"", " ").replace("\"", " ")) + .contains(" eight", " one", " ten", " nine"); + + } + @Test void call() { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java new file mode 100644 index 00000000000..47d3d2af7df --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.chat.client; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; + +import reactor.core.publisher.Flux; + +/** + * Drawing inspiration from the human strategy of re-reading, this advisor implements a + * re-reading strategy for LLM reasoning, dubbed RE2, to enhance understanding in the + * input phase. Based on the article: + * Re-Reading Improves Reasoning in Large + * Language Models + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private static final String DEFAULT_RE2_ADVISE_TEMPLATE = """ + {re2_input_query} + Read the question again: {re2_input_query} + """; + + private final String re2AdviseTemplate; + + private int order = 0; + + public ReReadingAdvisor() { + this(DEFAULT_RE2_ADVISE_TEMPLATE); + } + + public ReReadingAdvisor(String re2AdviseTemplate) { + this.re2AdviseTemplate = re2AdviseTemplate; + } + + public String getName() { + return this.getClass().getSimpleName(); + } + + private AdvisedRequest before(AdvisedRequest advisedRequest) { + + Map advisedUserParams = new HashMap<>(advisedRequest.userParams()); + advisedUserParams.put("re2_input_query", advisedRequest.userText()); + + return AdvisedRequest.from(advisedRequest) + .withUserText(this.re2AdviseTemplate) + .withUserParams(advisedUserParams) + .build(); + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + return 
chain.nextAroundCall(this.before(advisedRequest)); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + return chain.nextAroundStream(this.before(advisedRequest)); + } + + @Override + public int getOrder() { + return this.order; + } + + public ReReadingAdvisor withOrder(int order) { + this.order = order; + return this; + } + +} \ No newline at end of file diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 3954019d4cc..0572d4151b9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -28,11 +28,11 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import 
org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -65,7 +65,7 @@ public class VertexAiGeminiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { + private static class LoggingAdvisor implements CallAroundAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); @@ -75,7 +75,18 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var response = chain.nextAroundCall(before(advisedRequest)); + observeAfter(response); + return response; + } + + private AdvisedRequest before(AdvisedRequest request) { logger.info("System text: \n" + request.systemText()); logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); @@ -87,10 +98,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map return request; } - @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.info("Response: " + response); - return response; + private void observeAfter(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 910d8c0d014..e470b4b02ac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -21,43 +21,39 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; 
-import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor.StreamResponseMode; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.chat.model.StreamingChatModel; import 
org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.core.Ordered; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; import org.springframework.util.Assert; @@ -80,6 +76,7 @@ * @author Josh Long * @author Arjen Poutsma * @author Soby Chacko + * @author Dariusz Jedrzejczyk * @since 1.0.0 */ public class DefaultChatClient implements ChatClient { @@ -379,41 +376,14 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, Observation parentObservation) { - Map advisorContext = new ConcurrentHashMap<>(); - if (StringUtils.hasText(formatParam)) { - advisorContext.put("formatParam", formatParam); - } - advisorContext.putAll(inputRequestSpec.getAdvisorParams()); - - // DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; - AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); - if (!CollectionUtils.isEmpty(inputRequestSpec.advisors)) { - - // Apply the Request advisors - var currentAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractRequestAdvisors(inputRequestSpec.advisors)); - for (RequestAdvisor advisor : currentAdvisors) { - advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest, - advisorContext); - } - } + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam); // Apply the around 
advisor chain that terminates with the, last, model call // advisor. - ChatResponse advisedResponse = inputRequestSpec.aroundAdvisorChain.nextAroundCall(advisedRequest, - advisorContext); - - // Apply the Response advisors. - if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { - var currentAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractResponseAdvisors(inputRequestSpec.getAdvisors())); - for (ResponseAdvisor advisor : currentAdvisors) { - advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - advisedResponse, advisorContext); - } - } + AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build() + .nextAroundCall(advisedRequest); - return advisedResponse; + return advisedResponse.response(); } public ChatResponse chatResponse() { @@ -426,46 +396,6 @@ public String content() { } - private static Prompt toPrompt(AdvisedRequest advisedRequest, String formatParam) { - - var messages = new ArrayList(advisedRequest.messages()); - - String processedSystemText = advisedRequest.systemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(advisedRequest.systemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.systemParams()).render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - var processedUserText = StringUtils.hasText(formatParam) - ? 
advisedRequest.userText() + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.userText(); - - if (StringUtils.hasText(processedUserText)) { - - Map userParams = new HashMap<>(advisedRequest.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); - } - messages.add(new UserMessage(processedUserText, advisedRequest.media())); - } - - if (advisedRequest.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.functionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames())); - } - if (!advisedRequest.functionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks()); - } - } - - return new Prompt(messages, advisedRequest.chatOptions()); - } - public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; @@ -487,119 +417,22 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)) .start(); + var initialAdvisedRequest = toAdvisedRequest(inputRequest, ""); + // @formatter:off - return doGetFluxChatResponse(inputRequest, observation) - .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // Apply the around advisor chain that terminates with the, last, + // model call advisor. 
+ Flux stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest); + + return stream + .map(AdvisedResponse::response) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on }); } - record AdvisedRequestWithContext(AdvisedRequest request, Map advisorContext) { - } - - private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest, - Observation parentObservation) { - - Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - - var reqWithContext = new AdvisedRequestWithContext(toAdvisedRequest(inputRequest), advisorContext); - - return Flux.fromIterable(AdvisorObservableHelper.extractRequestAdvisors(inputRequest.advisors)) - .transformDeferredContextual((f, ctx) -> f - // This allows us to call blocking code in reduce - .publishOn(Schedulers.boundedElastic()) - .reduce(reqWithContext, (rwc, advisor) -> { - // Apply the Request advisors - AdvisedRequest advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, - advisor, rwc.request, rwc.advisorContext); - return new AdvisedRequestWithContext(advisedRequest, rwc.advisorContext); - })) - .single() - .flatMapMany(rwc -> { - - // Apply the around advisor chain that terminates with the, last, - // model call advisor. 
- Flux advisedResponse = inputRequest.aroundAdvisorChain.nextAroundStream(rwc.request, - rwc.advisorContext); - - // Apply the Response advisors - if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { - - var responseAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractResponseAdvisors(inputRequest.getAdvisors())); - - List perElementResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.PER_ELEMENT) - .toList(); - - List onFinishElementResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.ON_FINISH_ELEMENT) - .toList(); - - // PER_ELEMENT and ON_FINISH_ELEMENT - advisedResponse = advisedResponse.map(response -> { - // PER_ELEMENT - if (!CollectionUtils.isEmpty(perElementResponseAdvisors)) { - for (ResponseAdvisor advisor : perElementResponseAdvisors) { - response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - response, rwc.advisorContext); - } - } - // ON_FINISH_ELEMENT - if (!CollectionUtils.isEmpty(onFinishElementResponseAdvisors)) { - for (ResponseAdvisor advisor : onFinishElementResponseAdvisors) { - boolean withFinishReason = response.getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - - if (withFinishReason) { - response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - response, advisorContext); - } - } - } - return response; - }); - - // CUSTOM - // TODO: how to pass the parentObservation to the custom response - // advisor? 
- List customResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.CUSTOM) - .toList(); - if (!CollectionUtils.isEmpty(customResponseAdvisors)) { - for (ResponseAdvisor advisor : customResponseAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, rwc.advisorContext); - } - } - - // AGGREGATE - List aggregateResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.AGGREGATE) - .toList(); - - if (!CollectionUtils.isEmpty(aggregateResponseAdvisors)) { - advisedResponse = new MessageAggregator().aggregate(advisedResponse, chatResponse -> { - for (ResponseAdvisor advisor : aggregateResponseAdvisors) { - AdvisorObservableHelper.adviseResponse(parentObservation, advisor, chatResponse, - advisorContext); - } - }); - } - } - - return advisedResponse; - - }); - } - public Flux chatResponse() { return doGetObservableFluxChatResponse(this.request); } @@ -646,11 +479,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map advisorParams = new HashMap<>(); - private final DefaultAroundAdvisorChain aroundAdvisorChain; - - public AroundAdvisorChain getAroundAdvisorChain() { - return this.aroundAdvisorChain; - } + private final DefaultAroundAdvisorChain.Builder aroundAdvisorChainBuilder; private ObservationRegistry getObservationRegistry() { return this.observationRegistry; @@ -735,39 +564,50 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map adviceContext, - AroundAdvisorChain chain) { - String formatParam = (String) adviceContext.get("formatParam"); - return chatModel.call(toPrompt(advisedRequest, formatParam)); - } - }) - .push(new StreamAroundAdvisor() { - - @Override - public String getName() { - return StreamAroundAdvisor.class.getSimpleName(); - } - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - return 
chatModel.stream(toPrompt(advisedRequest, null)); - } - }) - .pushAll(this.advisors) - .build(); - // @formatter:on + // @formatter:off + // At the stack bottom add the non-streaming and streaming model call advisors. + // They play the role of the last advisor in the around advisor chain. + this.advisors.add(new CallAroundAdvisor() { + + @Override + public String getName() { + return CallAroundAdvisor.class.getSimpleName(); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext())); + } + }); + + this.advisors.add(new StreamAroundAdvisor() { + + @Override + public String getName() { + return StreamAroundAdvisor.class.getSimpleName(); + } + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + return chatModel.stream(advisedRequest.toPrompt()) + .map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))) + .publishOn(Schedulers.boundedElastic());// TODO add option to disable. + } + }); + // @formatter:on + + this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) + .pushAll(this.advisors); } /** @@ -797,21 +637,21 @@ public ChatClientRequestSpec advisors(Consumer consumer) consumer.accept(as); this.advisorParams.putAll(as.getParams()); this.advisors.addAll(as.getAdvisors()); - this.aroundAdvisorChain.pushAll(as.getAdvisors()); + this.aroundAdvisorChainBuilder.pushAll(as.getAdvisors()); return this; } public ChatClientRequestSpec advisors(Advisor... 
advisors) { Assert.notNull(advisors, "the advisors must be non-null"); this.advisors.addAll(Arrays.asList(advisors)); - this.aroundAdvisorChain.pushAll(Arrays.asList(advisors)); + this.aroundAdvisorChainBuilder.pushAll(Arrays.asList(advisors)); return this; } public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "the advisors must be non-null"); this.advisors.addAll(advisors); - this.aroundAdvisorChain.pushAll(advisors); + this.aroundAdvisorChainBuilder.pushAll(advisors); return this; } @@ -944,11 +784,16 @@ public StreamResponseSpec stream() { } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + if (StringUtils.hasText(formatParam)) { + advisorContext.put("formatParam", formatParam); + } + return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams); + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext); } public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index ab1ca3352d5..a0bb49b0b48 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -16,41 +16,78 @@ package org.springframework.ai.chat.client; 
-import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; - +import java.util.Collections; +import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and * {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a * chain of advisors with shared advise context. * - * @deprecated since 1.0.0 please use {@link RequestAdvisor}, {@link ResponseAdvisor} - * instead. + * @deprecated since 1.0.0 please use {@link CallAroundAdvisor} or + * {@link StreamAroundAdvisor} instead. 
* @author Christian Tzolov * @since 1.0.0 */ @Deprecated -public interface RequestResponseAdvisor extends RequestAdvisor, ResponseAdvisor { +public interface RequestResponseAdvisor extends CallAroundAdvisor, StreamAroundAdvisor { @Override default String getName() { return this.getClass().getSimpleName(); } - @Override default AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext) { return request; } - @Override default ChatResponse adviseResponse(ChatResponse response, Map adviseContext) { return response; } + default Flux adviseResponse(Flux fluxResponse, Map context) { + return fluxResponse; + } + + @Override + default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + var context = new HashMap<>(advisedRequest.adviseContext()); + var requestPrim = adviseRequest(advisedRequest, context); + advisedRequest = AdvisedRequest.from(requestPrim) + .withAdviseContext(Collections.unmodifiableMap(context)) + .build(); + + var advisedResponse = chain.nextAroundCall(advisedRequest); + + context = new HashMap<>(advisedResponse.adviseContext()); + var chatResponse = adviseResponse(advisedResponse.response(), context); + return new AdvisedResponse(chatResponse, Collections.unmodifiableMap(context)); + } + + default Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + ConcurrentHashMap context = new ConcurrentHashMap<>(advisedRequest.adviseContext()); + + advisedRequest = adviseRequest(advisedRequest, context); + + var advisedResponseStream = chain.nextAroundStream(advisedRequest); + + return this.adviseResponse(advisedResponseStream.map(ar -> ar.response()), context) + .map(chatResponse -> new AdvisedResponse(chatResponse, context)); + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java 
index abfb41db456..c9453f034f4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -17,19 +17,28 @@ package org.springframework.ai.chat.client.advisor; import java.util.Map; - -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import java.util.function.Function; + +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + /** * Abstract class that serves as a base for chat memory advisors. * * @param the type of the chat memory. 
* @author Christian Tzolov - * @since 1.0.0 M1 + * @since 1.0.0 */ -public abstract class AbstractChatMemoryAdvisor implements RequestAdvisor, ResponseAdvisor { +public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { public static final String CHAT_MEMORY_CONVERSATION_ID_KEY = "chat_memory_conversation_id"; @@ -45,11 +54,14 @@ public abstract class AbstractChatMemoryAdvisor implements RequestAdvisor, Re protected final int defaultChatMemoryRetrieveSize; + private final boolean protectFromBlocking; + public AbstractChatMemoryAdvisor(T chatMemory) { - this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE); + this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true); } - public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize) { + public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize, + boolean protectFromBlocking) { Assert.notNull(chatMemory, "The chatMemory must not be null!"); Assert.hasText(defaultConversationId, "The conversationId must not be empty!"); @@ -58,6 +70,7 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int this.chatMemoryStore = chatMemory; this.defaultConversationId = defaultConversationId; this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize; + this.protectFromBlocking = protectFromBlocking; } @Override @@ -66,8 +79,12 @@ public String getName() { } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; + public int getOrder() { + // The (Ordered.HIGHEST_PRECEDENCE + 1000) value gives this advisor a lower + // priority (i.e. precedence) than the internal Spring AI advisors. It leaves + // room (1000 slots) for the user to plug in their own advisors with higher + // priority.
+ return Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; } protected T getChatMemoryStore() { @@ -86,4 +103,20 @@ protected int doGetChatMemoryRetrieveSize(Map context) { : this.defaultChatMemoryRetrieveSize; } + protected Flux doNextWithProtectFromBlockingBefore(AdvisedRequest advisedRequest, + StreamAroundAdvisorChain chain, Function beforeAdvise) { + + // This can be executed by both blocking and non-blocking Threads + // E.g. a command line or Tomcat blocking Thread implementation + // or by a WebFlux dispatch in a non-blocking manner. + return (this.protectFromBlocking) ? + // @formatter:off + Mono.just(advisedRequest) + .publishOn(Schedulers.boundedElastic()) + .map(beforeAdvise) + .flatMapMany(request -> chain.nextAroundStream(request)) + : chain.nextAroundStream(beforeAdvise.apply(advisedRequest)); + } + + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index a3b56b75f11..c8fff5b2eac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -1,26 +1,56 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ package org.springframework.ai.chat.client.advisor; -import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Deque; import java.util.List; -import java.util.Map; +import java.util.concurrent.ConcurrentLinkedDeque; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention; +import org.springframework.core.OrderComparator; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import reactor.core.publisher.Flux; -public class DefaultAroundAdvisorChain implements AroundAdvisorChain { +/** + * Implementation of the {@link CallAroundAdvisorChain} and + * {@link StreamAroundAdvisorChain}. 
Used by the {@link org.springframework.ai.chat.client.ChatClient} to delegate the call + * to the next {@link CallAroundAdvisor} or {@link StreamAroundAdvisor} in the chain. + * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk + * @since 1.0.0 + */ +public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain { + + public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); private final Deque callAroundAdvisors; @@ -28,50 +58,20 @@ public class DefaultAroundAdvisorChain implements AroundAdvisorChain { private final ObservationRegistry observationRegistry; - public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry) { - this(observationRegistry, new ArrayDeque(), new ArrayDeque()); - } - - public DefaultAroundAdvisorChain(CallAroundAdvisor aroundAdvisor, ObservationRegistry observationRegistry) { - this(observationRegistry, new ArrayDeque(), new ArrayDeque()); - this.push(aroundAdvisor); - } + DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque callAroundAdvisors, + Deque streamAroundAdvisors) { - public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, - Deque callAroundAdvisors, Deque streamAroundAdvisors) { + Assert.notNull(observationRegistry, "the observationRegistry must be non-null"); Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null"); + Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null"); + this.observationRegistry = observationRegistry; this.callAroundAdvisors = callAroundAdvisors; this.streamAroundAdvisors = streamAroundAdvisors; } - public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, List advisors) { - this(observationRegistry); - Assert.notNull(advisors, "the advisors must be non-null"); - advisors.forEach(this::push); - } - - public void pushAll(List advisors) { - Assert.notNull(advisors, "the advisors must be non-null"); -
advisors.forEach(this::push); - } - - public void push(Advisor aroundAdvisor) { - - Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null"); - - if (aroundAdvisor instanceof CallAroundAdvisor callAroundAdvisor) { - this.callAroundAdvisors.push(callAroundAdvisor); - } - // Note: the advisor can implement both the CallAroundAdvisor and - // StreamAroundAdvisor. - if (aroundAdvisor instanceof StreamAroundAdvisor streamAroundAdvisor) { - this.streamAroundAdvisors.push(streamAroundAdvisor); - } - } - @Override - public ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext) { + public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { if (this.callAroundAdvisors.isEmpty()) { throw new IllegalStateException("No AroundAdvisor available to execute"); @@ -83,20 +83,17 @@ public ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map observationContext, - this.observationRegistry) - .observe(() -> advisor.aroundCall(advisedRequest, adviceContext, this)); + .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) + .observe(() -> advisor.aroundCall(advisedRequest, this)); } @Override - public Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext) { - + public Flux nextAroundStream(AdvisedRequest advisedRequest) { return Flux.deferContextual(contextView -> { - if (this.streamAroundAdvisors.isEmpty()) { return Flux.error(new IllegalStateException("No AroundAdvisor available to execute")); } @@ -107,21 +104,21 @@ public Flux nextAroundStream(AdvisedRequest advisedRequest, Map observationContext, - this.observationRegistry); + DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry); observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - return advisor.aroundStream(advisedRequest, adviceContext, this) - .doOnError(observation::error) - .doFinally(s -> { - observation.stop(); - }) - 
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:off + return Flux.defer(() -> advisor.aroundStream(advisedRequest, this)) + .doOnError(observation::error) + .doFinally(s -> { + observation.stop(); + }).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on }); } @@ -131,26 +128,67 @@ public static Builder builder(ObservationRegistry observationRegistry) { public static class Builder { - private final DefaultAroundAdvisorChain aroundAdvisorChain; + private final ObservationRegistry observationRegistry; + + private final Deque callAroundAdvisors; + + private final Deque streamAroundAdvisors; public Builder(ObservationRegistry observationRegistry) { - this.aroundAdvisorChain = new DefaultAroundAdvisorChain(observationRegistry); + this.observationRegistry = observationRegistry; + this.callAroundAdvisors = new ConcurrentLinkedDeque<>(); + this.streamAroundAdvisors = new ConcurrentLinkedDeque<>(); } public Builder push(Advisor aroundAdvisor) { Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null"); - this.aroundAdvisorChain.push(aroundAdvisor); - return this; + return this.pushAll(List.of(aroundAdvisor)); } - public Builder pushAll(List aroundAdvisors) { - Assert.notNull(aroundAdvisors, "the aroundAdvisors must be non-null"); - this.aroundAdvisorChain.pushAll(aroundAdvisors); + public Builder pushAll(List advisors) { + Assert.notNull(advisors, "the advisors must be non-null"); + if (!CollectionUtils.isEmpty(advisors)) { + List callAroundAdvisors = advisors.stream() + .filter(a -> a instanceof CallAroundAdvisor) + .map(a -> (CallAroundAdvisor) a) + .toList(); + + if (!CollectionUtils.isEmpty(callAroundAdvisors)) { + callAroundAdvisors.forEach(this.callAroundAdvisors::push); + } + + List streamAroundAdvisors = advisors.stream() + .filter(a -> a instanceof StreamAroundAdvisor) + .map(a -> (StreamAroundAdvisor) a) + .toList(); + + if 
(!CollectionUtils.isEmpty(streamAroundAdvisors)) { + streamAroundAdvisors.forEach(this.streamAroundAdvisors::push); + } + + this.reOrder(); + } return this; } + /** + * (Re)orders the advisors in priority order based on their Ordered attribute. + */ + private void reOrder() { + ArrayList callAdvisors = new ArrayList<>(this.callAroundAdvisors); + OrderComparator.sort(callAdvisors); + this.callAroundAdvisors.clear(); + callAdvisors.forEach(this.callAroundAdvisors::addLast); + + ArrayList streamAdvisors = new ArrayList<>(this.streamAroundAdvisors); + OrderComparator.sort(streamAdvisors); + this.streamAroundAdvisors.clear(); + streamAdvisors.forEach(this.streamAroundAdvisors::addLast); + } + public DefaultAroundAdvisorChain build() { - return this.aroundAdvisorChain; + return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAroundAdvisors, + this.streamAroundAdvisors); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 5f125f0bd32..66d4ee34755 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -18,19 +18,23 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; -import 
org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; + +import reactor.core.publisher.Flux; /** * Memory is retrieved added as a collection of messages to the prompt * * @author Christian Tzolov - * @since 1.0.0 M1 + * @since 1.0.0 */ public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor { @@ -39,15 +43,35 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory) { } public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize) { - super(chatMemory, defaultConversationId, chatHistoryWindowSize); + super(chatMemory, defaultConversationId, chatHistoryWindowSize, true); + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, + this::before); + + return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + } - String conversationId = this.doGetConversationId(context); + private AdvisedRequest before(AdvisedRequest request) { - int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(context); + String conversationId = this.doGetConversationId(request.adviseContext()); + + int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(request.adviseContext()); // 1. Retrieve the chat memory for the current conversation. 
List memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize); @@ -61,19 +85,20 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map // 4. Add the new user input to the conversation memory. UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage); + this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; } - @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { - - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + private void observeAfter(AdvisedResponse advisedResponse) { - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - return chatResponse; + this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index 19af746d607..961bca0f32b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -21,19 +21,24 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import 
org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.Content; +import reactor.core.publisher.Flux; + /** * Memory is retrieved added into the prompt's system text. * * @author Christian Tzolov - * @since 1.0.0 M1 + * @since 1.0.0 */ public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor { @@ -61,16 +66,37 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) { public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, String systemTextAdvise) { - super(chatMemory, defaultConversationId, chatHistoryWindowSize); + super(chatMemory, defaultConversationId, chatHistoryWindowSize, true); this.systemTextAdvise = systemTextAdvise; } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, + this::before); + + return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + } + + private AdvisedRequest before(AdvisedRequest request) { // 1. Advise system parameters. 
List memoryMessages = this.getChatMemoryStore() - .get(this.doGetConversationId(context), this.doGetChatMemoryRetrieveSize(context)); + .get(this.doGetConversationId(request.adviseContext()), + this.doGetChatMemoryRetrieveSize(request.adviseContext())); String memory = (memoryMessages != null) ? memoryMessages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) @@ -91,19 +117,20 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map // 4. Add the new user input to the conversation memory. UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage); + this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; } - @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { - - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + private void observeAfter(AdvisedResponse advisedResponse) { - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - return chatResponse; + this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 2b36a4c410d..c00f8e0bfb9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -19,11 +19,15 @@ import java.util.HashMap; 
import java.util.List; import java.util.Map; +import java.util.function.Predicate; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -34,6 +38,10 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + /** * Context for the question is retrieved from a Vector Store and added to the prompt's * user text. @@ -41,7 +49,7 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class QuestionAnswerAdvisor implements RequestAdvisor, ResponseAdvisor { +public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { private static final String DEFAULT_USER_TEXT_ADVISE = """ Context information is below. @@ -63,10 +71,24 @@ public class QuestionAnswerAdvisor implements RequestAdvisor, ResponseAdvisor { public static final String FILTER_EXPRESSION = "qa_filter_expression"; + private final boolean protectFromBlocking; + + /** + * The QuestionAnswerAdvisor retrieves context information from a Vector Store and + * combines it with the user's text. 
+ * @param vectorStore The vector store to use + */ public QuestionAnswerAdvisor(VectorStore vectorStore) { this(vectorStore, SearchRequest.defaults(), DEFAULT_USER_TEXT_ADVISE); } + /** + * The QuestionAnswerAdvisor retrieves context information from a Vector Store and + * combines it with the user's text. + * @param vectorStore The vector store to use + * @param searchRequest The search request defined using the portable filter + * expression syntax + */ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) { this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE); } @@ -79,9 +101,26 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques * expression syntax * @param userTextAdvise the user text to append to the existing user prompt. The text * should contain a placeholder named "question_answer_context". - * */ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) { + this(vectorStore, searchRequest, userTextAdvise, true); + } + + /** + * The QuestionAnswerAdvisor retrieves context information from a Vector Store and + * combines it with the user's text. + * @param vectorStore The vector store to use + * @param searchRequest The search request defined using the portable filter + * expression syntax + * @param userTextAdvise the user text to append to the existing user prompt. The text + * should contain a placeholder named "question_answer_context". + * @param protectFromBlocking if true, streaming calls run the request-advising + * step on a bounded-elastic scheduler, so it never blocks a non-blocking caller + * thread (e.g. WebFlux). If false, that step runs directly on the calling thread. + * Defaults to true.
+ */ + public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, + boolean protectFromBlocking) { Assert.notNull(vectorStore, "The vectorStore must not be null!"); Assert.notNull(searchRequest, "The searchRequest must not be null!"); @@ -90,6 +129,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques this.vectorStore = vectorStore; this.searchRequest = searchRequest; this.userTextAdvise = userTextAdvise; + this.protectFromBlocking = protectFromBlocking; } @Override @@ -98,7 +138,46 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public int getOrder() { + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + AdvisedRequest advisedRequest2 = before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest2); + + return after(advisedResponse); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + // This can be executed by both blocking and non-blocking Threads + // E.g. a command line or Tomcat blocking Thread implementation + // or by a WebFlux dispatch in a non-blocking manner. + Flux advisedResponses = (this.protectFromBlocking) ? + // @formatter:off + Mono.just(advisedRequest) + .publishOn(Schedulers.boundedElastic()) + .map(this::before) + .flatMapMany(request -> chain.nextAroundStream(request)) + : chain.nextAroundStream(before(advisedRequest)); + // @formatter:on + + return advisedResponses.map(ar -> { + if (onFinishReason().test(ar)) { + ar = after(ar); + } + return ar; + }); + } + + private AdvisedRequest before(AdvisedRequest request) { + + var context = new HashMap<>(request.adviseContext()); // 1. Advise the system text. 
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise; @@ -124,21 +203,16 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map AdvisedRequest advisedRequest = AdvisedRequest.from(request) .withUserText(advisedUserText) .withUserParams(advisedUserParams) + .withAdviseContext(context) .build(); return advisedRequest; } - @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response); - chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return chatResponseBuilder.build(); - } - - @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.ON_FINISH_ELEMENT; + private AdvisedResponse after(AdvisedResponse advisedResponse) { + ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); + chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS)); + return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); } protected Filter.Expression doGetFilterExpression(Map context) { @@ -151,4 +225,57 @@ protected Filter.Expression doGetFilterExpression(Map context) { } + private Predicate onFinishReason() { + return (advisedResponse) -> advisedResponse.response() + .getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + } + + public static Builder builder(VectorStore vectorStore) { + return new Builder(vectorStore); + } + + public static class Builder { + + private final VectorStore vectorStore; + + private SearchRequest searchRequest = SearchRequest.defaults(); + + private String userTextAdvise = DEFAULT_USER_TEXT_ADVISE; + + private boolean protectFromBlocking = true; + + private Builder(VectorStore 
vectorStore) { + Assert.notNull(vectorStore, "The vectorStore must not be null!"); + this.vectorStore = vectorStore; + } + + public Builder withSearchRequest(SearchRequest searchRequest) { + Assert.notNull(searchRequest, "The searchRequest must not be null!"); + this.searchRequest = searchRequest; + return this; + } + + public Builder withUserTextAdvise(String userTextAdvise) { + Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!"); + this.userTextAdvise = userTextAdvise; + return this; + } + + public Builder withProtectFromBlocking(boolean protectFromBlocking) { + this.protectFromBlocking = protectFromBlocking; + return this; + } + + public QuestionAnswerAdvisor build() { + return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise, + this.protectFromBlocking); + } + + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java new file mode 100644 index 00000000000..4871707e47c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.chat.client.advisor; + +import java.util.List; + +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.CollectionUtils; + +import reactor.core.publisher.Flux; + +/** + * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the + * response if the user input contains any of the sensitive words. + * + * @author Christian Tzolov + * @since 1.0.0 + */ +public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private final List sensitiveWords; + + public SafeGuardAdvisor(List sensitiveWords) { + this.sensitiveWords = sensitiveWords; + } + + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + if (!CollectionUtils.isEmpty(this.sensitiveWords) + && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + return new AdvisedResponse(ChatResponse.builder().withGenerations(List.of()).build(), + advisedRequest.adviseContext()); + } + + return chain.nextAroundCall(advisedRequest); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + if (!CollectionUtils.isEmpty(this.sensitiveWords) + && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { + return Flux.empty(); + } + + return chain.nextAroundStream(advisedRequest); + + } + + @Override + public int getOrder() { + 
return 0; + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index 393de09a460..f327c3b5504 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -15,23 +15,28 @@ */ package org.springframework.ai.chat.client.advisor; -import java.util.Map; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; +import reactor.core.publisher.Flux; + /** * A simple logger advisor that logs the request and response messages. 
* * @author Christian Tzolov */ -public class SimpleLoggerAdvisor implements RequestAdvisor, ResponseAdvisor { +public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); @@ -63,15 +68,17 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public int getOrder() { + return 0; + } + + private AdvisedRequest before(AdvisedRequest request) { logger.debug("request: {}", this.requestToString.apply(request)); return request; } - @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.debug("response: {}", this.responseToString.apply(response)); - return response; + private void observeAfter(AdvisedResponse advisedResponse) { + logger.debug("response: {}", this.responseToString.apply(advisedResponse.response())); } @Override @@ -80,8 +87,25 @@ public String toString() { } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + observeAfter(advisedResponse); + + return advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + advisedRequest = before(advisedRequest); + + Flux advisedResponses = chain.nextAroundStream(advisedRequest); + + return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index 9b545bf9352..b6ab9439877 100644 --- 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -21,22 +21,27 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; +import reactor.core.publisher.Flux; + /** * Memory is retrieved from a VectorStore added into the prompt's system text. 
* * @author Christian Tzolov - * @since 1.0.0 M1 + * @since 1.0.0 */ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor { @@ -73,18 +78,41 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, String systemTextAdvise) { - super(vectorStore, defaultConversationId, chatHistoryWindowSize); + super(vectorStore, defaultConversationId, chatHistoryWindowSize, true); this.systemTextAdvise = systemTextAdvise; } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + advisedRequest = this.before(advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + this.observeAfter(advisedResponse); + + return advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, + this::before); + + // The observeAfter will certainly be executed on non-blocking Threads in case + // of some models - e.g. 
when the model client is a WebClient + return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + } + + private AdvisedRequest before(AdvisedRequest request) { String advisedSystemText = request.systemText() + System.lineSeparator() + this.systemTextAdvise; var searchRequest = SearchRequest.query(request.userText()) - .withTopK(this.doGetChatMemoryRetrieveSize(context)) - .withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(context) + "'"); + .withTopK(this.doGetChatMemoryRetrieveSize(request.adviseContext())) + .withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + + this.doGetConversationId(request.adviseContext()) + "'"); List documents = this.getChatMemoryStore().similaritySearch(searchRequest); @@ -101,19 +129,22 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map .build(); UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().write(toDocuments(List.of(userMessage), this.doGetConversationId(context))); + this.getChatMemoryStore() + .write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext()))); return advisedRequest; } - @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { - - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + private void observeAfter(AdvisedResponse advisedResponse) { - this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context))); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - return chatResponse; + this.getChatMemoryStore() + .write(toDocuments(assistantMessages, this.doGetConversationId(advisedResponse.adviseContext()))); } private List toDocuments(List messages, String conversationId) { diff --git 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java similarity index 64% rename from spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 903c304dc3c..8630c697e40 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -14,17 +14,28 @@ * limitations under the License. */ -package org.springframework.ai.chat.client; +package org.springframework.ai.chat.client.advisor.api; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.springframework.ai.model.Media; -import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * The data of the chat client request that can be modified before the execution of the @@ -44,11 +55,18 @@ * @param systemParams the map of system parameters * @param advisors the list of request response advisors * @param advisorParams the map of advisor parameters + * @param 
adviseContext the map of advise context */ public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, List media, List functionNames, List functionCallbacks, List messages, Map userParams, Map systemParams, List advisors, - Map advisorParams) { + Map advisorParams, Map adviseContext) { + + public AdvisedRequest updateContext(Function, Map> contextTransform) { + return from(this) + .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) + .build(); + } public static Builder from(AdvisedRequest from) { Builder builder = new Builder(); @@ -64,7 +82,8 @@ public static Builder from(AdvisedRequest from) { builder.systemParams = from.systemParams; builder.advisors = from.advisors; builder.advisorParams = from.advisorParams; - builder.advisorParams = from.advisorParams; + builder.adviseContext = from.adviseContext; + return builder; } @@ -98,6 +117,8 @@ public static class Builder { private Map advisorParams = Map.of(); + private Map adviseContext = Map.of(); + public Builder withChatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; @@ -158,12 +179,58 @@ public Builder withAdvisorParams(Map advisorParams) { return this; } + public Builder withAdviseContext(Map adviseContext) { + this.adviseContext = adviseContext; + return this; + } + public AdvisedRequest build() { return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, - this.advisors, this.advisorParams); + this.advisors, this.advisorParams, this.adviseContext); + } + + } + + public Prompt toPrompt() { + + var messages = new ArrayList(this.messages()); + + String processedSystemText = this.systemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(this.systemParams())) { + processedSystemText = new 
PromptTemplate(processedSystemText, this.systemParams()).render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + String formatParam = (String) this.adviseContext().get("formatParam"); + + var processedUserText = StringUtils.hasText(formatParam) + ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(this.userParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, this.media())); + } + + if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!this.functionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); + } + if (!this.functionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); + } } + return new Prompt(messages, this.chatOptions()); } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java new file mode 100644 index 00000000000..8c81740cfd3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -0,0 +1,73 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.Assert; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public record AdvisedResponse(ChatResponse response, Map adviseContext) { + + public AdvisedResponse updateContext(Function, Map> contextTransform) { + return new AdvisedResponse(response, + Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(adviseContext)))); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ChatResponse response; + + private Map adviseContext; + + public Builder() { + } + + public static Builder from(AdvisedResponse advisedResponse) { + return new Builder().withResponse(advisedResponse.response) + .withAdviseContext(advisedResponse.adviseContext); + } + + public Builder withResponse(ChatResponse response) { + Assert.notNull(response, "the response must be non-null"); + this.response = response; + return this; + } + + public Builder withAdviseContext(Map adviseContext) { + Assert.notNull(adviseContext, "the adviseContext must be non-null"); + this.adviseContext = adviseContext; + return this; + } + + public AdvisedResponse build() { + Assert.notNull(this.adviseContext, "the adviseContext must be non-null"); + return new AdvisedResponse(response, adviseContext); + } + + } +} diff --git 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java index eaa09c6efb5..a4be691b423 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -15,18 +15,39 @@ */ package org.springframework.ai.chat.client.advisor.api; +import org.springframework.core.Ordered; + /** * Parent advisor interface for all advisors. * * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @since 1.0.0 - * @see RequestAdvisor - * @see ResponseAdvisor * @see CallAroundAdvisor * @see StreamAroundAdvisor - * @see AroundAdvisorChain + * @see CallAroundAdvisorChain */ -public interface Advisor { +public interface Advisor extends Ordered { + + /** + * Useful constant for the highest precedence value for ordering advisors. + */ + public static int HIGHEST_PRECEDENCE_ORDER = Ordered.HIGHEST_PRECEDENCE; + + /** + * Useful constant for the lowest precedence value for ordering advisors. Note that + * the values from Ordered.LOWEST_PRECEDENCE to Ordered.LOWEST_PRECEDENCE + 1000 are + * reserved for internal use within the Spring AI framework. + */ + public static int LOWEST_PRECEDENCE_ORDER = Ordered.LOWEST_PRECEDENCE + 1000; + + /** + * Useful constant for the default Chat Memory precedence order. Ensures this order + * has lower priority (e.g. precedences) than the Spring AI internal advisors. It + * leaves room (1000 slots) for the user to plug in their own advisors with higher + * priority. + */ + public static int DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER = Ordered.HIGHEST_PRECEDENCE + 1000; /** * @return the advisor name. 
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java deleted file mode 100644 index f33a686c24d..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.springframework.ai.chat.client.advisor.api; - -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; - -import reactor.core.publisher.Flux; - -public interface AroundAdvisorChain { - - ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext); - - Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext); - -} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java index 4ff26544e94..57d19df600e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -15,13 +15,9 @@ */ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; - /** * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @since 1.0.0 */ @@ -30,10 +26,9 @@ public interface CallAroundAdvisor extends Advisor { /** * Around advice that wraps the ChatModel#call(Prompt) method. 
* @param advisedRequest the advised request - * @param adviceContext the advice context * @param chain the advisor chain * @return the response */ - ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, AroundAdvisorChain chain); + AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain); } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java new file mode 100644 index 00000000000..56b3240b00e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java @@ -0,0 +1,30 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.api; + +/** + * The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain. + * Used for non-streaming responses. 
+ * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk + * @since 1.0.0 + */ +public interface CallAroundAdvisorChain { + + AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java deleted file mode 100644 index 8ba198e323a..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.Prompt; - -/** - * Advisor called before the {@link ChatModel#call(Prompt)} and - * {@link ChatModel#stream(Prompt)} methods are called. The {@link ChatClient} maintains a - * chain of advisors with shared advise context. - * - * @author Christian Tzolov - * @since 1.0.0 - */ -public interface RequestAdvisor extends Advisor { - - /** - * @param request the {@link AdvisedRequest} data to be advised. 
Represents the row - * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. - * @return the advised {@link AdvisedRequest}. - */ - AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext); - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java deleted file mode 100644 index e1040c91da0..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.Map; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.prompt.Prompt; - -import reactor.core.publisher.Flux; - -/** - * Advisor called after the {@link ChatModel#call(Prompt)} (or - * {@link ChatModel#stream(Prompt)}) method call. The {@link ChatClient} maintains a chain - * of advisors with shared advise context. 
- * - * @author Christian Tzolov - * @since 1.0.0 - */ -public interface ResponseAdvisor extends Advisor { - - /** - * @param response the {@link ChatResponse} data to be advised. Represents the row - * {@link ChatResponse} data after the {@link ChatModel#call(Prompt)} method is - * called. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. - * @return the advised {@link ChatResponse}. - */ - ChatResponse adviseResponse(ChatResponse response, Map adviseContext); - - /** - * Different modes of advising the streaming responses. - */ - public enum StreamResponseMode { - - /** - * Called for each response element in the Flux. The response advisor can modify - * the elements before they are returned to the client. - */ - PER_ELEMENT, - /** - * Called only on Flux elements that contain a finish reason. Usually the last - * element in the Flux. The response advisor can modify the elements before they - * are returned to the client. - */ - ON_FINISH_ELEMENT, - /** - * Called only once after all Flux elements have been consumed. All elements are - * merged into a single ChatResponse element and provided to the response advisor - * to process.
- * Mind that at that stage the response advisor can not longer modify the response - * returned to the client. - */ - AGGREGATE, - /** - * Delegates to the stream advisor implementation. - */ - CUSTOM; - - } - - default StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.ON_FINISH_ELEMENT; - } - - /** - * @param fluxResponse the streaming {@link ChatResponse} data to be advised. - * Represents the row {@link ChatResponse} stream data after the - * {@link ChatModel#stream(Prompt)} method is called. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. - * @return the advised {@link ChatResponse} flux. - */ - default Flux adviseResponse(Flux fluxResponse, Map adviseContext) { - return fluxResponse; - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java index d2379a7724a..69653dc45e5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -15,15 +15,11 @@ */ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; - import reactor.core.publisher.Flux; /** * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @since 1.0.0 */ public interface StreamAroundAdvisor extends Advisor { @@ -31,11 +27,9 @@ public interface StreamAroundAdvisor extends Advisor { /** * Around advice that wraps the invocation of the advised request. 
* @param advisedRequest - * @param adviceContext * @param chain the chain of advisors to execute * @return the result of the advised request */ - Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain); + Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain); } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java new file mode 100644 index 00000000000..49d5e0998c3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java @@ -0,0 +1,32 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.api; + +import reactor.core.publisher.Flux; + +/** + * The StreamAroundAdvisorChain is used to delegate the call to the next + * StreamAroundAdvisor in the chain. Used for streaming responses. 
+ * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk + * @since 1.0.0 + */ +public interface StreamAroundAdvisorChain { + + Flux nextAroundStream(AdvisedRequest advisedRequest); + +} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java deleted file mode 100644 index 51dcc6b5ec9..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java +++ /dev/null @@ -1,149 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*/ -package org.springframework.ai.chat.client.advisor.around; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.util.CollectionUtils; - -import reactor.core.publisher.Flux; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class CacheAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - - private final VectorStore vectorStore; - - private static final String DOCUMENT_METADATA_ADVISOR_CACHE_TAG = "advisorCacheDocument"; - - private static final String DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE = "advisorCacheResponse"; - - public CacheAroundAdvisor(VectorStore vectorStore) { - this.vectorStore = vectorStore; - } - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); - if (cachedResponseOption.isPresent()) { - return cachedResponseOption.get(); - } - - ChatResponse chatResponse = 
chain.nextAroundCall(advisedRequest, adviceContext); - - saveCacheEntry(advisedRequest.userText(), chatResponse); - - return chatResponse; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); - if (cachedResponseOption.isPresent()) { - return Flux.just(cachedResponseOption.get()); - } - - Flux fluxChatResponse = chain.nextAroundStream(advisedRequest, adviceContext); - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - saveCacheEntry(advisedRequest.userText(), chatResponse); - }); - } - - private void saveCacheEntry(String userQuestion, ChatResponse chatResponse) { - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); - if (!CollectionUtils.isEmpty(assistantMessages)) { - this.vectorStore.add(toDocuments(userQuestion, assistantMessages)); - } - } - - private Optional getCacheEntry(AdvisedRequest advisedRequest, Map adviceContext) { - - // TODO: convert into pompty first or at least materialize the user params. - String userText = advisedRequest.userText(); - - // @formatter:off - var searchRequest = SearchRequest.query(userText) - .withSimilarityThreshold(0.95) - .withTopK(1) - .withFilterExpression("'"+ DOCUMENT_METADATA_ADVISOR_CACHE_TAG + "' == 'true'"); - // @formatter:on - - List doc = vectorStore.similaritySearch(searchRequest); - - // return cached response - return CollectionUtils.isEmpty(doc) ? 
Optional.empty() : Optional.of(fromDocument(doc.get(0))); - } - - private ChatResponse fromDocument(Document doc) { - - if (!doc.getMetadata().containsKey(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE)) { - throw new IllegalStateException("The document is missing the cache response metadata!"); - } - String cachedResponse = "" + doc.getMetadata().get(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE); - - return ChatResponse.builder() - .withGenerations(List.of(new Generation(new AssistantMessage(cachedResponse, Map.of()), - ChatGenerationMetadata.from("STOP", null)))) - .build(); - } - - private List toDocuments(String userQuestion, List messages) { - - List docs = messages.stream() - .filter(m -> m.getMessageType() == MessageType.ASSISTANT) - .map(message -> { - var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>()); - metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_TAG, "true"); - metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE, message.getContent()); - // TODO: Pehaps we need to serialize the message metadata to the document - - return new Document(userQuestion, metadata); - - }) - .toList(); - - return docs; - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java deleted file mode 100644 index f682ceb1b70..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java +++ /dev/null @@ -1,59 +0,0 @@ -package org.springframework.ai.chat.client.advisor.around; - -import java.util.List; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; 
-import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.util.CollectionUtils; - -import reactor.core.publisher.Flux; - -/** - * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the - * response if the user input contains any of the sensitive words. - * - * @author Christian Tzolov - * @since 1.0.0 - */ -public class SafeGuardAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - - private final List sensitiveWords; - - public SafeGuardAroundAdvisor(List sensitiveWords) { - this.sensitiveWords = sensitiveWords; - } - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { - return ChatResponse.builder().withGenerations(List.of()).build(); - } - - return chain.nextAroundCall(advisedRequest, adviceContext); - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { - return Flux.empty(); - } - - return chain.nextAroundStream(advisedRequest, adviceContext); - - } - -} \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java deleted file mode 100644 index e267a5ee2fe..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java +++ /dev/null @@ -1,103 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. 
-* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.chat.client.advisor.observation; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.util.CollectionUtils; - -import io.micrometer.observation.Observation; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public abstract class AdvisorObservableHelper { - - public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); - - public static AdvisedRequest adviseRequest(Observation parentObservation, RequestAdvisor advisor, - AdvisedRequest advisedRequest, Map advisorContext) { - - var observationContext = AdvisorObservationContext.builder() - .withAdvisorName(advisor.getName()) - .withAdvisorType(AdvisorObservationContext.Type.BEFORE) - .withAdvisedRequest(advisedRequest) - .withAdvisorRequestContext(advisorContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - parentObservation.getObservationRegistry()) - 
.parentObservation(parentObservation) - .observe(() -> advisor.adviseRequest(advisedRequest, advisorContext)); - } - - public static ChatResponse adviseResponse(Observation parentObservation, ResponseAdvisor advisor, - ChatResponse response, Map advisorContext) { - - var observationContext = AdvisorObservationContext.builder() - .withAdvisorName(advisor.getName()) - .withAdvisorType(AdvisorObservationContext.Type.AFTER) - .withAdvisorRequestContext(advisorContext) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, - parentObservation.getObservationRegistry()) - .parentObservation(parentObservation) - .observe(() -> advisor.adviseResponse(response, advisorContext)); - } - - public static List extractRequestAdvisors(List advisors) { - return advisors.stream() - .filter(advisor -> advisor instanceof RequestAdvisor) - .map(a -> (RequestAdvisor) a) - .toList(); - } - - /** - * Extracts the {@link ResponseAdvisor} instances from the given list of advisors and - * returns them in reverse order. - * @param advisors list of all registered advisor types. - * @return the list of {@link ResponseAdvisor} instances in reverse order. 
- */ - public static List extractResponseAdvisors(List advisors) { - - var list = advisors.stream() - .filter(advisor -> advisor instanceof ResponseAdvisor) - .map(a -> (ResponseAdvisor) a) - .toList(); - - // reverse the list - if (CollectionUtils.isEmpty(list)) { - return list; - } - - var reversedList = new ArrayList<>(list); - Collections.reverse(reversedList); - return Collections.unmodifiableList(reversedList); - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index 5a4421e98ca..3346c7b56e2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -17,8 +17,8 @@ import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.util.Assert; @@ -59,6 +59,11 @@ public enum Type { */ private Map advisorResponseContext; + /** + * The order of the advisor in the advisor chain. 
+ */ + private int order; + public void setAdvisorName(String advisorName) { this.advisorName = advisorName; } @@ -99,6 +104,14 @@ public void setAdvisorResponseContext(Map advisorResponseContext this.advisorResponseContext = advisorResponseContext; } + public int getOrder() { + return this.order; + } + + public void setOrder(int order) { + this.order = order; + } + public static Builder builder() { return new Builder(); } @@ -132,6 +145,11 @@ public Builder withAdvisorResponseContext(Map advisorResponseCon return this; } + public Builder withOrder(int order) { + this.context.setOrder(order); + return this; + } + public AdvisorObservationContext build() { Assert.hasText(this.context.advisorName, "The advisorName must not be empty!"); Assert.notNull(this.context.advisorType, "The advisorType must not be null!"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java index 5bb022b0879..67eafbf0bb3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java @@ -81,6 +81,15 @@ public enum HighCardinalityKeyNames implements KeyName { public String asString() { return "spring.ai.chat.client.advisor.name"; } + }, + /** + * Advisor order in the advisor chain. 
+ */ + ADVISOR_ORDER { + @Override + public String asString() { + return "spring.ai.chat.client.advisor.order"; + } }; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java index e6babe01891..43ab4ca5c83 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java @@ -90,7 +90,7 @@ protected KeyValue springAiKind() { @Override public KeyValues getHighCardinalityKeyValues(AdvisorObservationContext context) { - return KeyValues.of(advisorName(context)); + return KeyValues.of(advisorName(context), advisorOrder(context)); } protected KeyValue advisorName(AdvisorObservationContext context) { @@ -100,4 +100,8 @@ protected KeyValue advisorName(AdvisorObservationContext context) { return ADVISOR_NAME_NONE; } + protected KeyValue advisorOrder(AdvisorObservationContext context) { + return KeyValue.of(HighCardinalityKeyNames.ADVISOR_ORDER, "" + context.getOrder()); + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java index 6d383c0c43c..c4e9b33f0db 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java @@ -30,6 +30,8 @@ */ public interface ChatMemory { + // TODO: consider a non-blocking interface for streaming usages + default void add(String conversationId, Message message) { this.add(conversationId, List.of(message)); } diff --git 
a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 6aef10ed777..1c6bfc70225 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -46,6 +47,27 @@ public class MessageAggregator { private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class); + public Flux aggregateAdvisedResponse(Flux advisedResponses, + Consumer aggregationHandler) { + + AtomicReference> adviseContext = new AtomicReference<>(new HashMap<>()); + + return new MessageAggregator().aggregate(advisedResponses.map(ar -> { + adviseContext.get().putAll(ar.adviseContext()); + return ar.response(); + + }), aggregatedChatResponse -> { + + AdvisedResponse aggregatedAdvisedResponse = AdvisedResponse.builder() + .withResponse(aggregatedChatResponse) + .withAdviseContext(adviseContext.get()) + .build(); + + aggregationHandler.accept(aggregatedAdvisedResponse); + + }).map(cr -> new AdvisedResponse(cr, adviseContext.get())); + } + public Flux aggregate(Flux fluxChatResponse, Consumer onAggregationComplete) { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java index 432029fb813..a9ea8730e75 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java +++ 
b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java @@ -16,8 +16,10 @@ package org.springframework.ai.chat.client; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -26,11 +28,10 @@ import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.model.ChatModel; @@ -38,8 +39,7 @@ import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.when; +import reactor.core.publisher.Flux; /** * @author Christian Tzolov @@ -61,19 +61,17 @@ private String join(Flux fluxContent) { public void promptChatMemory() { when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation("Hello John")))) - .thenReturn(new ChatResponse(List.of(new Generation("Your name is John")))); + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))))); ChatMemory chatMemory = new InMemoryChatMemory(); var chatClient = ChatClient.builder(chatModel) - .defaultSystem("Default system text.") - .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) - .build(); + .defaultSystem("Default system text.") + .defaultAdvisors(new 
PromptChatMemoryAdvisor(chatMemory)) + .build(); - var content = chatClient.prompt() - .user("my name is John") - .call().content(); + var content = chatClient.prompt().user("my name is John").call().content(); assertThat(content).isEqualTo("Hello John"); @@ -92,9 +90,7 @@ public void promptChatMemory() { Message userMessage = promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); - content = chatClient.prompt() - .user("What is my name?") - .call().content(); + content = chatClient.prompt().user("What is my name?").call().content(); assertThat(content).isEqualTo("Your name is John"); @@ -119,31 +115,28 @@ public void promptChatMemory() { @Test public void streamingPromptChatMemory() { - when(chatModel.stream(promptCaptor.capture())) - .thenReturn( - Flux.generate(() -> new ChatResponse(List.of(new Generation("Hello John"))), (state, sink) -> { - sink.next(state); - sink.complete(); - return state; - })) - .thenReturn( - Flux.generate(() -> new ChatResponse(List.of(new Generation("Your name is John"))), - (state, sink) -> { - sink.next(state); - sink.complete(); - return state; - })); + when(chatModel.stream(promptCaptor.capture())).thenReturn(Flux.generate( + () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John")))), (state, sink) -> { + sink.next(state); + sink.complete(); + return state; + })) + .thenReturn(Flux.generate( + () -> new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John")))), + (state, sink) -> { + sink.next(state); + sink.complete(); + return state; + })); ChatMemory chatMemory = new InMemoryChatMemory(); var chatClient = ChatClient.builder(chatModel) - .defaultSystem("Default system text.") - .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) - .build(); + .defaultSystem("Default system text.") + .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory)) + .build(); - var content = join(chatClient.prompt() - 
.user("my name is John") - .stream().content()); + var content = join(chatClient.prompt().user("my name is John").stream().content()); assertThat(content).isEqualTo("Hello John"); @@ -162,9 +155,7 @@ public void streamingPromptChatMemory() { Message userMessage = promptCaptor.getValue().getInstructions().get(1); assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("my name is John"); - content = join(chatClient.prompt() - .user("What is my name?") - .stream().content()); + content = join(chatClient.prompt().user("What is my name?").stream().content()); assertThat(content).isEqualTo("Your name is John"); @@ -186,87 +177,97 @@ public void streamingPromptChatMemory() { assertThat(userMessage.getContent()).isEqualToIgnoringWhitespace("What is my name?"); } - public static class MockAdvisor implements RequestResponseAdvisor { + // public static class MockAdvisor implements RequestResponseAdvisor { - public AdvisedRequest advisedRequest; + // public AdvisedRequest advisedRequest; - public Map advisedRequestContext; + // public Map advisedRequestContext; - public Map chatResponseContext; + // public Map chatResponseContext; - public ChatResponse chatResponse; + // public ChatResponse chatResponse; - public Map fluxChatResponseContext; + // public Map fluxChatResponseContext; - public Flux fluxChatResponse; + // public Flux fluxChatResponse; - @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { - advisedRequest = request; - advisedRequestContext = context; + // @Override + // public AdvisedRequest adviseRequest(AdvisedRequest request, Map + // context) { + // advisedRequest = request; + // advisedRequestContext = context; - context.put("adviseRequest", "adviseRequest"); + // context.put("adviseRequest", "adviseRequest"); - return request; - } + // return request; + // } - @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - chatResponse = response; - chatResponseContext = context; + // 
@Override + // public ChatResponse adviseResponse(ChatResponse response, Map + // context) { + // chatResponse = response; + // chatResponseContext = context; - context.put("adviseResponse", "adviseResponse"); - return response; - } + // context.put("adviseResponse", "adviseResponse"); + // return response; + // } - @Override - public Flux adviseResponse(Flux fluxResponse, Map context) { - fluxChatResponse = fluxResponse; - fluxChatResponseContext = context; + // @Override + // public Flux adviseResponse(Flux fluxResponse, + // Map context) { + // fluxChatResponse = fluxResponse; + // fluxChatResponseContext = context; - context.put("fluxAdviseResponse", "fluxAdviseResponse"); + // context.put("fluxAdviseResponse", "fluxAdviseResponse"); - return fluxResponse; - } + // return fluxResponse; + // } - }; + // }; - @Test - public void advisors() { + // @Test + // public void advisors() { - var mockAdvisor = new MockAdvisor(); + // var mockAdvisor = new MockAdvisor(); - when(chatModel.call(promptCaptor.capture())).thenReturn(new ChatResponse(List.of(new Generation("Hello John")))) - .thenReturn(new ChatResponse(List.of(new Generation("Your name is John")))); + // when(chatModel.call(promptCaptor.capture())) + // .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello + // John"))))) + // .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name + // is John"))))); - when(chatModel.call(promptCaptor.capture())).thenReturn(new ChatResponse(List.of(new Generation("Hello John")))) - .thenReturn(new ChatResponse(List.of(new Generation("Your name is John")))); + // when(chatModel.call(promptCaptor.capture())) + // .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello + // John"))))) + // .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name + // is John"))))); - var chatClient = ChatClient.builder(chatModel) - .defaultSystem("Default system text.") - 
.defaultAdvisors(mockAdvisor) - .build(); + // var chatClient = ChatClient.builder(chatModel) + // .defaultSystem("Default system text.") + // .defaultAdvisors(mockAdvisor) + // .build(); - var content = chatClient.prompt() - .user("my name is John") - .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) - .call() - .content(); + // var content = chatClient.prompt() + // .user("my name is John") + // .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) + // .call() + // .content(); - assertThat(content).isEqualTo("Hello John"); + // assertThat(content).isEqualTo("Hello John"); - assertThat(mockAdvisor.advisedRequestContext).containsEntry("key1", "value1") - .containsEntry("key2", "value2") - .containsEntry("adviseRequest", "adviseRequest"); - assertThat(mockAdvisor.advisedRequest.advisorParams()).containsEntry("key1", "value1") - .containsEntry("key2", "value2") - .doesNotContainKey("adviseRequest"); - - assertThat(mockAdvisor.chatResponseContext).containsEntry("key1", "value1") - .containsEntry("key2", "value2") - .containsEntry("adviseRequest", "adviseRequest") - .containsEntry("adviseResponse", "adviseResponse"); - assertThat(mockAdvisor.chatResponse).isNotNull(); - } + // assertThat(mockAdvisor.advisedRequestContext).containsEntry("key1", "value1") + // .containsEntry("key2", "value2") + // .containsEntry("adviseRequest", "adviseRequest"); + // assertThat(mockAdvisor.advisedRequest.advisorParams()).containsEntry("key1", + // "value1") + // .containsEntry("key2", "value2") + // .doesNotContainKey("adviseRequest"); + + // assertThat(mockAdvisor.chatResponseContext).containsEntry("key1", "value1") + // .containsEntry("key2", "value2") + // .containsEntry("adviseRequest", "adviseRequest") + // .containsEntry("adviseResponse", "adviseResponse"); + // assertThat(mockAdvisor.chatResponse).isNotNull(); + // } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java 
b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java new file mode 100644 index 00000000000..1272bcebda3 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -0,0 +1,215 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import 
org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; + +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +public class AdvisorsTests { + + @Mock + ChatModel chatModel; + + @Captor + ArgumentCaptor promptCaptor; + + public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + public AdvisedRequest advisedRequest; + + public AdvisedResponse advisedResponse; + + public List aroundAdvisedResponses = new ArrayList<>(); + + private final String name; + + private final int order; + + public MockAroundAdvisor(String name, int order) { + this.name = name; + this.order = order; + } + + @Override + public String getName() { + return name; + } + + @Override + public int getOrder() { + return order; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundCallBefore" + name, "AROUND_CALL_BEFORE " + name); + context.put("lastBefore", name); + return context; + }); + + AdvisedResponse advisedResponse = this.advisedResponse = chain.nextAroundCall(this.advisedRequest); + + this.advisedResponse = advisedResponse.updateContext(context -> { + context.put("aroundCallAfter" + name, "AROUND_CALL_AFTER " + name); + context.put("lastAfter", name); + return context; + }); + + return this.advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.updateContext(context -> { + context.put("aroundStreamBefore" + name, "AROUND_STREAM_BEFORE " + name); + context.put("lastBefore", name); + return context; + }); + + Flux advisedResponseStream = 
chain.nextAroundStream(this.advisedRequest); + + return advisedResponseStream.map(advisedResponse -> { + return advisedResponse.updateContext(context -> { + context.put("aroundStreamAfter" + name, "AROUND_STREAM_AFTER " + name); + context.put("lastAfter", name); + return context; + }); + }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); + + } + + } + + @Test + public void callAdvisorsContextPropagation() { + + // Order==0 has higher priority than order == 1. The lower the order the higher + // the priority. + var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); + var mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); + + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))); + + var chatClient = ChatClient.builder(chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(mockAroundAdvisor1) + .build(); + + var content = chatClient.prompt() + .user("my name is John") + .advisors(mockAroundAdvisor2) + .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) + .call() + .content(); + + assertThat(content).isEqualTo("Hello John"); + + // AROUND + assertThat(mockAroundAdvisor1.advisedResponse.response()).isNotNull(); + assertThat(mockAroundAdvisor1.advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1") + .containsEntry("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1") + .containsEntry("aroundCallBeforeAdvisor2", "AROUND_CALL_BEFORE Advisor2") + .containsEntry("aroundCallAfterAdvisor2", "AROUND_CALL_AFTER Advisor2") + .containsEntry("lastBefore", "Advisor2") // inner + .containsEntry("lastAfter", "Advisor1"); // outer + + verify(chatModel).call(promptCaptor.capture()); + } + + @Test + public void streamAdvisorsContextPropagation() { + + var mockAroundAdvisor1 = new MockAroundAdvisor("Advisor1", 0); + var 
mockAroundAdvisor2 = new MockAroundAdvisor("Advisor2", 1); + + when(chatModel.stream(promptCaptor.capture())) + .thenReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); + + var chatClient = ChatClient.builder(chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(mockAroundAdvisor1) + .build(); + + var content = chatClient.prompt() + .user("my name is John") + .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) + .advisors(mockAroundAdvisor2) + .stream() + .content() + .collectList() + .block() + .stream() + .collect(Collectors.joining()); + + assertThat(content).isEqualTo("Hello John"); + + // AROUND + assertThat(mockAroundAdvisor1.aroundAdvisedResponses).isNotEmpty(); + + mockAroundAdvisor1.aroundAdvisedResponses.stream().forEach(advisedResponse -> { + assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("aroundStreamBeforeAdvisor1", "AROUND_STREAM_BEFORE Advisor1") + .containsEntry("aroundStreamAfterAdvisor1", "AROUND_STREAM_AFTER Advisor1") + .containsEntry("aroundStreamBeforeAdvisor2", "AROUND_STREAM_BEFORE Advisor2") + .containsEntry("aroundStreamAfterAdvisor2", "AROUND_STREAM_AFTER Advisor2") + .containsEntry("lastBefore", "Advisor2") // inner + .containsEntry("lastAfter", "Advisor1"); // outer + }); + + verify(chatModel).stream(promptCaptor.capture()); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java index e56219f8e64..33d3fdecdb8 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java +++ 
b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java @@ -42,17 +42,17 @@ void shouldHaveName() { void contextualName() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .withAdvisorName("MyName") - .withAdvisorType(AdvisorObservationContext.Type.BEFORE) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.getContextualName(observationContext)) - .isEqualTo("chat_client_advisor my_name_before"); + .isEqualTo("chat_client_advisor my_name_around"); } @Test void supportsAdvisorObservationContext() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .withAdvisorName("MyName") - .withAdvisorType(AdvisorObservationContext.Type.BEFORE) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); @@ -62,10 +62,10 @@ void supportsAdvisorObservationContext() { void shouldHaveLowCardinalityKeyValuesWhenDefined() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .withAdvisorName("MyName") - .withAdvisorType(AdvisorObservationContext.Type.AFTER) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( - KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE.asString(), "AFTER"), + KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE.asString(), "AROUND"), KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND.asString(), "chat_client_advisor")); } @@ -73,11 +73,13 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { void shouldHaveKeyValuesWhenDefinedAndResponse() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .withAdvisorName("MyName") 
- .withAdvisorType(AdvisorObservationContext.Type.AFTER) + .withAdvisorType(AdvisorObservationContext.Type.AROUND) + .withOrder(678) .build(); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)) - .contains(KeyValue.of(HighCardinalityKeyNames.ADVISOR_NAME.asString(), "MyName")); + .contains(KeyValue.of(HighCardinalityKeyNames.ADVISOR_NAME.asString(), "MyName")) + .contains(KeyValue.of(HighCardinalityKeyNames.ADVISOR_ORDER.asString(), "678")); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 7067139d6a3..0f36c2c9782 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -25,8 +25,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; @@ -96,6 +96,11 @@ public String getName() { return name; } + @Override + public int getOrder() { + return 0; + } + @Override public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { return request; @@ -148,7 +153,8 @@ void shouldHaveOptionalKeyValues() { 
ChatClientObservationContext observationContext = new ChatClientObservationContext(request, "json", true); assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains( - KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), "[\"advisor1\",\"advisor2\"]"), + KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), + "[\"advisor1\",\"advisor2\",\"CallAroundAdvisor\",\"StreamAroundAdvisor\"]"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), "[\"advParam1\":\"advisorParam1Value\"]"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-api-classes.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-api-classes.jpg new file mode 100644 index 00000000000..2bec61341e6 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-api-classes.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-flow.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-flow.jpg new file mode 100644 index 00000000000..a277b190e5a Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-flow.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-non-stream-vs-stream.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-non-stream-vs-stream.jpg new file mode 100644 index 00000000000..1ff2af7b399 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/advisors-non-stream-vs-stream.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 873cce6f444..93940645ae7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -3,6 +3,7 @@ * xref:getting-started.adoc[Getting Started] * 
xref:api/index.adoc[] ** xref:api/chatclient.adoc[] +*** xref:api/advisors.adoc[Advisors] ** xref:api/chatmodel.adoc[] *** xref:api/bedrock-chat.adoc[Amazon Bedrock] **** xref:api/chat/bedrock/bedrock-anthropic3.adoc[Anthropic3] @@ -89,7 +90,6 @@ *** xref:api/vectordbs/typesense.adoc[] *** xref:api/vectordbs/weaviate.adoc[] - ** xref:api/functions.adoc[Function Calling] ** xref:api/multimodality.adoc[Multimodality] ** xref:api/prompt.adoc[] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc new file mode 100644 index 00000000000..7ba80e82c06 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/advisors.adoc @@ -0,0 +1,233 @@ +[[Advisors]] + += Advisors API + +The Spring AI Advisors API provides a flexible and powerful way to intercept, modify, and enhance AI-driven interactions in your Spring applications. +By leveraging the Advisors API, developers can create more sophisticated, reusable, and maintainable AI components. + +The key benefits include encapsulating recurring Generative AI patterns, transforming data sent to and from Language Models (LLMs), and providing portability across various models and use cases. 
+ +You can configure existing advisors using the xref:api/chatclient.adoc#_advisor_configuration_in_chatclient[ChatClient API] as shown in the following example: + +[source,java] +---- +var chatClient = ChatClient.builder(chatModel) + .defaultAdvisors( + new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor + new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()) // RAG advisor + ) + .build(); + +String response = chatClient.prompt() + // Set advisor parameters at runtime + .advisors(advisor -> advisor.param("chat_memory_conversation_id", "678") + .param("chat_memory_response_size", 100)) + .user(userText) + .call() + .content(); +---- + +It is recommended to register the advisors at build time using the builder's `defaultAdvisors()` method. + +== Core Components + +The API consists of `CallAroundAdvisor` and `CallAroundAdvisorChain` for non-streaming scenarios, and `StreamAroundAdvisor` and `StreamAroundAdvisorChain` for streaming scenarios. +It also includes `AdvisedRequest` to represent the unsealed Prompt request and `AdvisedResponse` for the Chat Completion response. Both hold an `advise-context` to share state across the advisor chain. + +image::advisors-api-classes.jpg[Advisors API Classes, width=600, align="center"] + +The `aroundCall()` and the `aroundStream()` are the key advisor methods, typically performing actions such as examining the unsealed Prompt data, customizing and augmenting the Prompt data, invoking the next entity in the advisor chain, optionally blocking the request, examining the chat completion response, and throwing exceptions to indicate processing errors. + +In addition, the `getOrder()` method determines advisor order in the chain, while `getName()` provides a unique advisor name. + +The Advisor Chain, created by the Spring AI framework, allows sequential invocation of multiple advisors ordered by their `getOrder()` values. +The lower values are executed first. 
+The last advisor, added automatically, sends the request to the LLM. + +The following flow diagram illustrates the interaction between the advisor chain and the Chat Model: + +image::advisors-flow.jpg[Advisors API Flow, width=400, align="left"] + +. The Spring AI framework creates an `AdvisedRequest` from the user's `Prompt` along with an empty `AdvisorContext` object. +. Each advisor in the chain processes the request, potentially modifying it. Alternatively, it can choose to block the request by not making the call to invoke the next entity. In the latter case, the advisor is responsible for filling out the response. +. The final advisor, provided by the framework, sends the request to the `Chat Model`. +. The Chat Model's response is then passed back through the advisor chain and converted into `AdvisedResponse`. The latter includes the shared `AdvisorContext` instance. +. Each advisor can process or modify the response. +. The final `AdvisedResponse` is returned to the client by extracting the `ChatCompletion`. + +=== Advisor Order + +The order of advisors in the chain is determined by the `getOrder()` method. +Advisors with lower order values are executed first. +Because the advisor chain is a stack, the first advisor in the chain is the first to process the request and the last to process the response. +If you want to ensure that an advisor is executed last, set its order close to the `Advisor.LOWEST_PRECEDENCE_ORDER` value and, vice versa, to execute it first, set its order close to the `Advisor.HIGHEST_PRECEDENCE_ORDER` value. +If you have multiple advisors with the same order value, the order of execution is not guaranteed. + +TIP: For use cases that need to be first in the chain on the input side and first on the output side, you have to use separate advisors for each side, configured with different order values and use the advisor context to share state between them. 
+ +== Implementing an Advisor + +To create an advisor, implement either `CallAroundAdvisor` or `StreamAroundAdvisor` (or both). The key method to implement is `aroundCall()` for non-streaming or `aroundStream()` for streaming advisors. + +=== Examples + +We will provide a few hands-on examples to illustrate how to implement advisors for observing and augmenting use-cases. + +==== Logging Advisor + +We can implement a simple logging advisor that logs the `AdvisedRequest` before and the `AdvisedResponse` after the call to the next advisor in the chain. +Note that the advisor only observes the request and response and does not modify them. +This implementation supports both non-streaming and streaming scenarios. + +[source,java] +---- +public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); + + @Override + public String getName() { // <1> + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { // <2> + return 0; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + + logger.debug("BEFORE: {}", advisedRequest); + + AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + + logger.debug("AFTER: {}", advisedResponse); + + return advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + logger.debug("BEFORE: {}", advisedRequest); + + Flux advisedResponses = chain.nextAroundStream(advisedRequest); + + return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, + advisedResponse -> logger.debug("AFTER: {}", advisedResponse)); // <3> + } +} +---- +<1> Provides a unique name for the advisor. +<2> You can control the order of execution by setting the order value. Lower values execute first. 
+<3> The `MessageAggregator` is a utility class that aggregates the Flux responses into a single AdvisedResponse. +This can be useful for logging or other processing that observes the entire response rather than individual items in the stream. +Note that you cannot alter the response in the `MessageAggregator` as it is a read-only operation. + +==== Re-Reading (Re2) Advisor + +The "https://arxiv.org/pdf/2309.06275[Re-Reading Improves Reasoning in Large Language Models]" article introduces a technique called Re-Reading (Re2) that improves the reasoning capabilities of Large Language Models. +The Re2 technique requires augmenting the input prompt like this: + +---- +{Input_Query} +Read the question again: {Input_Query} +---- + +Implementing an advisor that applies the Re2 technique to the user's input query can be done like this: + +[source,java] +---- +public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + + private AdvisedRequest before(AdvisedRequest advisedRequest) { // <1> + + Map advisedUserParams = new HashMap<>(advisedRequest.userParams()); + advisedUserParams.put("re2_input_query", advisedRequest.userText()); + + return AdvisedRequest.from(advisedRequest) + .withUserText(""" + {re2_input_query} + Read the question again: {re2_input_query} + """) + .withUserParams(advisedUserParams) + .build(); + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { // <2> + return chain.nextAroundCall(this.before(advisedRequest)); + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { // <3> + return chain.nextAroundStream(this.before(advisedRequest)); + } + + @Override + public int getOrder() { // <4> + return 0; + } + + @Override + public String getName() { // <5> + return this.getClass().getSimpleName(); + } +} +---- +<1> The `before` method augments the user's input query applying the Re-Reading technique. 
+<2> The `aroundCall` method intercepts the non-streaming request and applies the Re-Reading technique. +<3> The `aroundStream` method intercepts the streaming request and applies the Re-Reading technique. +<4> You can control the order of execution by setting the order value. Lower values execute first. +<5> Provides a unique name for the advisor. + +==== Spring AI built-in Advisors + +You can also explore the built-in advisors provided by the Spring AI framework. +For example, the `MessageChatMemoryAdvisor`, `PromptChatMemoryAdvisor` and `VectorStoreChatMemoryAdvisor` advisors provide different strategies for managing the conversation chat history in a chat memory store and the `QuestionAnswerAdvisor` uses a vector store to provide question-answering capabilities (i.e., it implements the RAG pattern). + +The `SafeGuardAdvisor` is another, simple, built-in advisor that can be used to prevent the model from generating harmful or inappropriate content. + +=== Streaming vs Non-Streaming + +image::advisors-non-stream-vs-stream.jpg[Advisors Streaming vs Non-Streaming Flow, width=800, align="left"] + +* Non-streaming advisors work with complete requests and responses. +* Streaming advisors handle requests and responses as continuous streams, using reactive programming concepts (e.g., Flux for responses). + + +// TODO - Add a section on how to implement a streaming advisor with blocking and non-blocking code. + +[source,java] +---- +@Override +public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + + return Mono.just(advisedRequest) + .publishOn(Schedulers.boundedElastic()) + .map(request -> { + // This can be executed by blocking and non-blocking Threads. + // Advisor before next section + }) + .flatMapMany(request -> chain.nextAroundStream(request)) + .map(response -> { + // Advisor after next section + }); +} +---- + +=== Best Practices + +. Keep advisors focused on specific tasks for better modularity. +. 
Use the `adviseContext` to share state between advisors when necessary. +. Implement both streaming and non-streaming versions of your advisor for maximum flexibility. +. Carefully consider the order of advisors in your chain to ensure proper data flow. + + +== Backward Compatibility + +IMPORTANT: The `AdvisedRequest` class is moved to a new package. +While the `RequestResponseAdvisor` interface is still available it is marked as deprecated and will be removed around the M3 release. +It is recommended to use the new `CallAroundAdvisor` and `StreamAroundAdvisor` interfaces for new implementations. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 9d9c584a759..1be8ec5fce3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -300,7 +300,7 @@ At the `ChatClient.Builder` level, you can specify the default prompt configurat * `defaultUser(String text)`, `defaultUser(Resource text)`, `defaultUser(Consumer userSpecConsumer)`: These methods let you define the user text. The `Consumer` allows you to use a lambda to specify the user text and any default parameters. -* `defaultAdvisors(RequestResponseAdvisor... advisor)`: Advisors allow modification of the data used to create the `Prompt`. The `QuestionAnswerAdvisor` implementation enables the pattern of `Retrieval Augmented Generation` by appending the prompt with context information related to the user text. +* `defaultAdvisors(Advisor... advisor)`: Advisors allow modification of the data used to create the `Prompt`. The `QuestionAnswerAdvisor` implementation enables the pattern of `Retrieval Augmented Generation` by appending the prompt with context information related to the user text. 
* `defaultAdvisors(Consumer advisorSpecConsumer)`: This method allows you to define a `Consumer` to configure multiple advisors using the `AdvisorSpec`. Advisors can modify the data used to create the final `Prompt`. The `Consumer` lets you specify a lambda to add advisors, such as `QuestionAnswerAdvisor`, which supports `Retrieval Augmented Generation` by appending the prompt with relevant context information based on the user text. @@ -315,12 +315,14 @@ java.util.function.Function function)` * `user(String text)`, `user(Resource text)`, `user(Consumer userSpecConsumer)` -* `advisors(RequestResponseAdvisor... advisor)` +* `advisors(Advisor... advisor)` * `advisors(Consumer advisorSpecConsumer)` == Advisors +The xref:api/advisors.adoc[Advisors API] provides a flexible and powerful way to intercept, modify, and enhance AI-driven interactions in your Spring applications. + A common pattern when calling an AI model with user text is to append or augment the prompt with contextual data. This contextual data can be of different types. Common types include: @@ -349,18 +351,18 @@ IMPORTANT: The order in which advisors are added to the chain is crucial, as it [source,java] ---- ChatClient.builder(chatModel) + .defaultAdvisors( + new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor + new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()) // RAG advisor + ) .build() .prompt() - .advisors( - new ChatMemoryAdvisor(chatMemory), - new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()) - ) .user(userText) - .call(); + .call() + .content(); ---- -In this configuration, the `ChatMemoryAdvisor` will be executed first, adding the conversation history to the prompt. Then, the `QuestionAnswerAdvisor` will perform its search based on the user's question and the added conversation history, potentially providing more relevant results. 
- +In this configuration, the `MessageChatMemoryAdvisor` will be executed first, adding the conversation history to the prompt. Then, the `QuestionAnswerAdvisor` will perform its search based on the user's question and the added conversation history, potentially providing more relevant results. === Retrieval Augmented Generation @@ -374,8 +376,8 @@ Assuming you have already loaded data into a `VectorStore`, you can perform Retr [source,java] ---- ChatResponse response = ChatClient.builder(chatModel) - .build().prompt() - .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) + .defaultAdvisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) + .build().prompt() .user(userText) .call() .chatResponse(); diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java index d37962d428d..341b1dd5480 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -16,15 +16,21 @@ package org.springframework.ai.vectorstore; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; + import org.jetbrains.annotations.NotNull; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; -import org.mockito.invocation.InvocationOnMock; -import 
org.mockito.stubbing.Answer; import org.postgresql.ds.PGSimpleDataSource; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; @@ -41,12 +47,6 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import java.util.List; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; - /** * @author Fabian Krüger * @author Soby Chacko