diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index 5b0cd73e441..a052a3a4f8a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -28,11 +28,11 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.OpenAiChatModel; @@ -65,7 +65,7 @@ public class OpenAiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { + private static class LoggingAdvisor implements BeforeAdvisor, AfterAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); @@ -74,7 +74,7 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { logger.info("System text: \n" + request.systemText()); logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); @@ -87,9 +87,9 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map } @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.info("Response: " + response); - return response; + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); + return advisedResponse; } } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index 3954019d4cc..ccd0b679941 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -28,11 +28,11 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; @@ -65,7 +65,7 @@ public class VertexAiGeminiPaymentTransactionIT { record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor { + private static class LoggingAdvisor implements BeforeAdvisor, AfterAdvisor { private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); @@ -75,7 +75,7 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { logger.info("System text: \n" + request.systemText()); logger.info("System params: " + request.systemParams()); logger.info("User text: \n" + request.userText()); @@ -88,9 +88,9 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map } @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.info("Response: " + response); - return response; + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + logger.info("Response: " + advisedResponse.response()); + return advisedResponse; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 910d8c0d014..6456ef08fc6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -21,20 +21,24 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor.StreamResponseMode; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor.AfterStreamMode; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -379,41 +383,23 @@ private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec in private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec, String formatParam, Observation parentObservation) { - Map advisorContext = new ConcurrentHashMap<>(); - if (StringUtils.hasText(formatParam)) { - advisorContext.put("formatParam", formatParam); - } - advisorContext.putAll(inputRequestSpec.getAdvisorParams()); - - // DefaultChatClientRequestSpec advisedRequestSpec = inputRequestSpec; - AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec); - if (!CollectionUtils.isEmpty(inputRequestSpec.advisors)) { - - // Apply the Request advisors - var currentAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractRequestAdvisors(inputRequestSpec.advisors)); - for (RequestAdvisor advisor : currentAdvisors) { - advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest, - advisorContext); - } + AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam); + + // Apply the Request advisors + for (BeforeAdvisor advisor : AdvisorObservableHelper.requestAdvisors(inputRequestSpec.advisors)) { + advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, advisor, advisedRequest); } // Apply the around advisor chain that terminates with the, last, model call // advisor. - ChatResponse advisedResponse = inputRequestSpec.aroundAdvisorChain.nextAroundCall(advisedRequest, - advisorContext); + AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChain.nextAroundCall(advisedRequest); // Apply the Response advisors. - if (!CollectionUtils.isEmpty(inputRequestSpec.getAdvisors())) { - var currentAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractResponseAdvisors(inputRequestSpec.getAdvisors())); - for (ResponseAdvisor advisor : currentAdvisors) { - advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - advisedResponse, advisorContext); - } + for (AfterAdvisor advisor : AdvisorObservableHelper.responseAdvisors(inputRequestSpec.getAdvisors())) { + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, advisedResponse); } - return advisedResponse; + return advisedResponse.response(); } public ChatResponse chatResponse() { @@ -426,46 +412,6 @@ public String content() { } - private static Prompt toPrompt(AdvisedRequest advisedRequest, String formatParam) { - - var messages = new ArrayList(advisedRequest.messages()); - - String processedSystemText = advisedRequest.systemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(advisedRequest.systemParams())) { - processedSystemText = new PromptTemplate(processedSystemText, advisedRequest.systemParams()).render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - var processedUserText = StringUtils.hasText(formatParam) - ? advisedRequest.userText() + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.userText(); - - if (StringUtils.hasText(processedUserText)) { - - Map userParams = new HashMap<>(advisedRequest.userParams()); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = new PromptTemplate(processedUserText, userParams).render(); - } - messages.add(new UserMessage(processedUserText, advisedRequest.media())); - } - - if (advisedRequest.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.functionNames().isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames())); - } - if (!advisedRequest.functionCallbacks().isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks()); - } - } - - return new Prompt(messages, advisedRequest.chatOptions()); - } - public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; @@ -498,104 +444,110 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ }); } - record AdvisedRequestWithContext(AdvisedRequest request, Map advisorContext) { - } - private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest, Observation parentObservation) { - Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + var advisedRequest0 = toAdvisedRequest(inputRequest, ""); - var reqWithContext = new AdvisedRequestWithContext(toAdvisedRequest(inputRequest), advisorContext); - - return Flux.fromIterable(AdvisorObservableHelper.extractRequestAdvisors(inputRequest.advisors)) + return Flux.fromIterable(AdvisorObservableHelper.requestAdvisors(inputRequest.advisors)) .transformDeferredContextual((f, ctx) -> f // This allows us to call blocking code in reduce .publishOn(Schedulers.boundedElastic()) - .reduce(reqWithContext, (rwc, advisor) -> { + .reduce(advisedRequest0, (ar, advisor) -> { // Apply the Request advisors - AdvisedRequest advisedRequest = AdvisorObservableHelper.adviseRequest(parentObservation, - advisor, rwc.request, rwc.advisorContext); - return new AdvisedRequestWithContext(advisedRequest, rwc.advisorContext); + return AdvisorObservableHelper.adviseRequest(parentObservation, advisor, ar); })) .single() - .flatMapMany(rwc -> { + .flatMapMany(advisedRequest -> { // Apply the around advisor chain that terminates with the, last, // model call advisor. - Flux advisedResponse = inputRequest.aroundAdvisorChain.nextAroundStream(rwc.request, - rwc.advisorContext); + Flux advisedResponses = inputRequest.aroundAdvisorChain + .nextAroundStream(advisedRequest); // Apply the Response advisors if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { var responseAdvisors = new ArrayList<>( - AdvisorObservableHelper.extractResponseAdvisors(inputRequest.getAdvisors())); + AdvisorObservableHelper.responseAdvisors(inputRequest.getAdvisors())); - List perElementResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.PER_ELEMENT) + List perElementResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getAfterStreamMode() == AfterStreamMode.PER_ELEMENT) .toList(); - List onFinishElementResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.ON_FINISH_ELEMENT) + List onFinishElementResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getAfterStreamMode() == AfterStreamMode.ON_FINISH_ELEMENT) .toList(); // PER_ELEMENT and ON_FINISH_ELEMENT - advisedResponse = advisedResponse.map(response -> { + advisedResponses = advisedResponses.map(advisedResponse -> { + // PER_ELEMENT if (!CollectionUtils.isEmpty(perElementResponseAdvisors)) { - for (ResponseAdvisor advisor : perElementResponseAdvisors) { - response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - response, rwc.advisorContext); + for (AfterAdvisor advisor : perElementResponseAdvisors) { + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse); } } // ON_FINISH_ELEMENT if (!CollectionUtils.isEmpty(onFinishElementResponseAdvisors)) { - for (ResponseAdvisor advisor : onFinishElementResponseAdvisors) { - boolean withFinishReason = response.getResults() - .stream() - .filter(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())) - .findFirst() - .isPresent(); - - if (withFinishReason) { - response = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, - response, advisorContext); + boolean withFinishReason = advisedResponse.response() + .getResults() + .stream() + .filter(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())) + .findFirst() + .isPresent(); + + if (withFinishReason) { + for (AfterAdvisor advisor : onFinishElementResponseAdvisors) { + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, + advisor, advisedResponse); } } } - return response; + + return advisedResponse; }); // CUSTOM // TODO: how to pass the parentObservation to the custom response // advisor? - List customResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.CUSTOM) + List customResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getAfterStreamMode() == AfterStreamMode.CUSTOM) .toList(); if (!CollectionUtils.isEmpty(customResponseAdvisors)) { - for (ResponseAdvisor advisor : customResponseAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, rwc.advisorContext); + for (AfterAdvisor advisor : customResponseAdvisors) { + advisedResponses = advisor.afterStream(advisedResponses); } } // AGGREGATE - List aggregateResponseAdvisors = responseAdvisors.stream() - .filter(a -> a.getStreamResponseMode() == StreamResponseMode.AGGREGATE) + List aggregateResponseAdvisors = responseAdvisors.stream() + .filter(a -> a.getAfterStreamMode() == AfterStreamMode.AGGREGATE) .toList(); if (!CollectionUtils.isEmpty(aggregateResponseAdvisors)) { - advisedResponse = new MessageAggregator().aggregate(advisedResponse, chatResponse -> { - for (ResponseAdvisor advisor : aggregateResponseAdvisors) { - AdvisorObservableHelper.adviseResponse(parentObservation, advisor, chatResponse, - advisorContext); + AtomicReference> adviseContext = new AtomicReference<>(new HashMap<>()); + + advisedResponses = new MessageAggregator().aggregate(advisedResponses.map(ar -> { + adviseContext.get().putAll(ar.adviseContext()); + return ar.response(); + }), chatResponse -> { + AdvisedResponse advisedResponse = AdvisedResponse.builder() + .withResponse(chatResponse) + .withAdviseContext(adviseContext.get()) + .build(); + for (AfterAdvisor advisor : aggregateResponseAdvisors) { + advisedResponse = AdvisorObservableHelper.adviseResponse(parentObservation, advisor, + advisedResponse); + adviseContext.set(advisedResponse.adviseContext()); } - }); + }).map(cr -> new AdvisedResponse(cr, adviseContext.get())); } } - return advisedResponse; + return advisedResponses.map(ar -> ar.response()); }); } @@ -747,10 +699,8 @@ public String getName() { } @Override - public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - String formatParam = (String) adviceContext.get("formatParam"); - return chatModel.call(toPrompt(advisedRequest, formatParam)); + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { + return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext())); } }) .push(new StreamAroundAdvisor() { @@ -760,9 +710,9 @@ public String getName() { return StreamAroundAdvisor.class.getSimpleName(); } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - return chatModel.stream(toPrompt(advisedRequest, null)); + public Flux aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { + return chatModel.stream(advisedRequest.toPrompt()) + .map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext()))); } }) .pushAll(this.advisors) @@ -944,11 +894,16 @@ public StreamResponseSpec stream() { } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { + private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest, String formatParam) { + Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); + if (StringUtils.hasText(formatParam)) { + advisorContext.put("formatParam", formatParam); + } + return new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams); + inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext); } public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index ab1ca3352d5..d7111c6753e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -16,41 +16,74 @@ package org.springframework.ai.chat.client; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; +import reactor.core.publisher.Flux; + /** * Advisor called before and after the {@link ChatModel#call(Prompt)} and * {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a * chain of advisors with shared advise context. * - * @deprecated since 1.0.0 please use {@link RequestAdvisor}, {@link ResponseAdvisor} - * instead. + * @deprecated since 1.0.0 please use {@link BeforeAdvisor}, {@link AfterAdvisor} instead. * @author Christian Tzolov * @since 1.0.0 */ @Deprecated -public interface RequestResponseAdvisor extends RequestAdvisor, ResponseAdvisor { +public interface RequestResponseAdvisor extends BeforeAdvisor, AfterAdvisor { @Override default String getName() { return this.getClass().getSimpleName(); } - @Override default AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext) { return request; } @Override + default AdvisedRequest before(AdvisedRequest request) { + var context = new HashMap<>(request.adviseContext()); + var requestPrim = adviseRequest(request, context); + return AdvisedRequest.from(requestPrim).withAdviseContext(Collections.unmodifiableMap(context)).build(); + } + default ChatResponse adviseResponse(ChatResponse response, Map adviseContext) { return response; } + @Override + default AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + var context = new HashMap<>(advisedResponse.adviseContext()); + var chatResponse = adviseResponse(advisedResponse.response(), context); + return new AdvisedResponse(chatResponse, Collections.unmodifiableMap(context)); + } + + default Flux adviseResponse(Flux fluxResponse, Map context) { + return fluxResponse; + } + + @Override + default Flux afterStream(Flux advisedResponseStream) { + + // TODO: this allows to modify the context for each chat response element in the + // stream. + return advisedResponseStream.map(advisedResponse -> { + var context = new HashMap<>(advisedResponse.adviseContext()); + var chatResponse = adviseResponse(advisedResponse.response(), context); + return new AdvisedResponse(chatResponse, Collections.unmodifiableMap(context)); + }); + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index abfb41db456..4b8727c5fb2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -18,8 +18,8 @@ import java.util.Map; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; import org.springframework.util.Assert; /** @@ -29,7 +29,7 @@ * @author Christian Tzolov * @since 1.0.0 M1 */ -public abstract class AbstractChatMemoryAdvisor implements RequestAdvisor, ResponseAdvisor { +public abstract class AbstractChatMemoryAdvisor implements BeforeAdvisor, AfterAdvisor { public static final String CHAT_MEMORY_CONVERSATION_ID_KEY = "chat_memory_conversation_id"; @@ -66,8 +66,8 @@ public String getName() { } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; + public AfterStreamMode getAfterStreamMode() { + return AfterStreamMode.AGGREGATE; } protected T getChatMemoryStore() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index a3b56b75f11..c6ebb74bc9f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -3,9 +3,9 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; @@ -13,7 +13,6 @@ import org.springframework.ai.chat.client.advisor.observation.AdvisorObservableHelper; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.util.Assert; import io.micrometer.observation.ObservationRegistry; @@ -71,7 +70,7 @@ public void push(Advisor aroundAdvisor) { } @Override - public ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext) { + public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { if (this.callAroundAdvisors.isEmpty()) { throw new IllegalStateException("No AroundAdvisor available to execute"); @@ -83,17 +82,17 @@ public ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map observationContext, this.observationRegistry) - .observe(() -> advisor.aroundCall(advisedRequest, adviceContext, this)); + .observe(() -> advisor.aroundCall(advisedRequest, this)); } @Override - public Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext) { + public Flux nextAroundStream(AdvisedRequest advisedRequest) { return Flux.deferContextual(contextView -> { @@ -107,7 +106,7 @@ public Flux nextAroundStream(AdvisedRequest advisedRequest, Map nextAroundStream(AdvisedRequest advisedRequest, Map { - observation.stop(); - }) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + return advisor.aroundStream(advisedRequest, this).doOnError(observation::error).doFinally(s -> { + observation.stop(); + }).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); }); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 5f125f0bd32..2008a041fbd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -18,13 +18,12 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; /** * Memory is retrieved added as a collection of messages to the prompt @@ -43,11 +42,11 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversatio } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { - String conversationId = this.doGetConversationId(context); + String conversationId = this.doGetConversationId(request.adviseContext()); - int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(context); + int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(request.adviseContext()); // 1. Retrieve the chat memory for the current conversation. List memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize); @@ -61,19 +60,23 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map // 4. Add the new user input to the conversation memory. UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage); + this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; } @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); + this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); - return chatResponse; + return advisedResponse; } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index 19af746d607..3c8c7484c17 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -21,12 +21,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.Content; /** @@ -66,11 +66,12 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversation } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { // 1. Advise system parameters. List memoryMessages = this.getChatMemoryStore() - .get(this.doGetConversationId(context), this.doGetChatMemoryRetrieveSize(context)); + .get(this.doGetConversationId(request.adviseContext()), + this.doGetChatMemoryRetrieveSize(request.adviseContext())); String memory = (memoryMessages != null) ? memoryMessages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) @@ -91,19 +92,23 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map // 4. Add the new user input to the conversation memory. UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().add(this.doGetConversationId(context), userMessage); + this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; } @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - this.getChatMemoryStore().add(this.doGetConversationId(context), assistantMessages); + this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); - return chatResponse; + return advisedResponse; } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index 2b36a4c410d..ff00d403512 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -21,9 +21,10 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -41,7 +42,7 @@ * @author Christian Tzolov * @since 1.0.0 */ -public class QuestionAnswerAdvisor implements RequestAdvisor, ResponseAdvisor { +public class QuestionAnswerAdvisor implements BeforeAdvisor, AfterAdvisor { private static final String DEFAULT_USER_TEXT_ADVISE = """ Context information is below. @@ -98,7 +99,9 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { + + var context = new HashMap<>(request.adviseContext()); // 1. Advise the system text. String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise; @@ -124,21 +127,22 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map AdvisedRequest advisedRequest = AdvisedRequest.from(request) .withUserText(advisedUserText) .withUserParams(advisedUserParams) + .withAdviseContext(context) .build(); return advisedRequest; } @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response); - chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)); - return chatResponseBuilder.build(); + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); + chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS)); + return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.ON_FINISH_ELEMENT; + public AfterStreamMode getAfterStreamMode() { + return AfterStreamMode.ON_FINISH_ELEMENT; } protected Filter.Expression doGetFilterExpression(Map context) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index 393de09a460..a473de846d5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -20,9 +20,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.model.ModelOptionsUtils; @@ -31,7 +32,7 @@ * * @author Christian Tzolov */ -public class SimpleLoggerAdvisor implements RequestAdvisor, ResponseAdvisor { +public class SimpleLoggerAdvisor implements BeforeAdvisor, AfterAdvisor { private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); @@ -63,15 +64,15 @@ public String getName() { } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { logger.debug("request: {}", this.requestToString.apply(request)); return request; } @Override - public ChatResponse adviseResponse(ChatResponse response, Map context) { - logger.debug("response: {}", this.responseToString.apply(response)); - return response; + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + logger.debug("response: {}", this.responseToString.apply(advisedResponse.response())); + return advisedResponse; } @Override @@ -80,8 +81,8 @@ public String toString() { } @Override - public StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.AGGREGATE; + public AfterStreamMode getAfterStreamMode() { + return AfterStreamMode.AGGREGATE; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index 9b545bf9352..c828c8b8e1a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -21,12 +21,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; import org.springframework.ai.vectorstore.SearchRequest; @@ -78,13 +78,14 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve } @Override - public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + public AdvisedRequest before(AdvisedRequest request) { String advisedSystemText = request.systemText() + System.lineSeparator() + this.systemTextAdvise; var searchRequest = SearchRequest.query(request.userText()) - .withTopK(this.doGetChatMemoryRetrieveSize(context)) - .withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(context) + "'"); + .withTopK(this.doGetChatMemoryRetrieveSize(request.adviseContext())) + .withFilterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + + this.doGetConversationId(request.adviseContext()) + "'"); List documents = this.getChatMemoryStore().similaritySearch(searchRequest); @@ -101,19 +102,25 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map .build(); UserMessage userMessage = new UserMessage(request.userText(), request.media()); - this.getChatMemoryStore().write(toDocuments(List.of(userMessage), this.doGetConversationId(context))); + this.getChatMemoryStore() + .write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext()))); return advisedRequest; } @Override - public ChatResponse adviseResponse(ChatResponse chatResponse, Map context) { + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); + List assistantMessages = advisedResponse.response() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); - this.getChatMemoryStore().write(toDocuments(assistantMessages, this.doGetConversationId(context))); + this.getChatMemoryStore() + .write(toDocuments(assistantMessages, this.doGetConversationId(advisedResponse.adviseContext()))); - return chatResponse; + return advisedResponse; } private List toDocuments(List messages, String conversationId) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java similarity index 64% rename from spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 903c304dc3c..b4c1a7fbaa9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -14,17 +14,28 @@ * limitations under the License. */ -package org.springframework.ai.chat.client; +package org.springframework.ai.chat.client.advisor.api; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.springframework.ai.model.Media; -import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * The data of the chat client request that can be modified before the execution of the @@ -44,11 +55,18 @@ * @param systemParams the map of system parameters * @param advisors the list of request response advisors * @param advisorParams the map of advisor parameters + * @param adviseContext the map of advise context */ public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions, List media, List functionNames, List functionCallbacks, List messages, Map userParams, Map systemParams, List advisors, - Map advisorParams) { + Map advisorParams, Map adviseContext) { + + public AdvisedRequest augmentContext(Function, Map> contextTransform) { + return from(this) + .withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) + .build(); + } public static Builder from(AdvisedRequest from) { Builder builder = new Builder(); @@ -64,7 +82,8 @@ public static Builder from(AdvisedRequest from) { builder.systemParams = from.systemParams; builder.advisors = from.advisors; builder.advisorParams = from.advisorParams; - builder.advisorParams = from.advisorParams; + builder.adviseContext = from.adviseContext; + return builder; } @@ -98,6 +117,8 @@ public static class Builder { private Map advisorParams = Map.of(); + private Map adviseContext = Map.of(); + public Builder withChatModel(ChatModel chatModel) { this.chatModel = chatModel; return this; @@ -158,12 +179,58 @@ public Builder withAdvisorParams(Map advisorParams) { return this; } + public Builder withAdviseContext(Map adviseContext) { + this.adviseContext = adviseContext; + return this; + } + public AdvisedRequest build() { return new AdvisedRequest(chatModel, this.userText, this.systemText, this.chatOptions, this.media, this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams, - this.advisors, this.advisorParams); + this.advisors, this.advisorParams, this.adviseContext); + } + + } + + public Prompt toPrompt() { + + var messages = new ArrayList(this.messages()); + + String processedSystemText = this.systemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(this.systemParams())) { + processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render(); + } + messages.add(new SystemMessage(processedSystemText)); + } + + String formatParam = (String) this.adviseContext().get("formatParam"); + + var processedUserText = StringUtils.hasText(formatParam) + ? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText(); + + if (StringUtils.hasText(processedUserText)) { + + Map userParams = new HashMap<>(this.userParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + if (!CollectionUtils.isEmpty(userParams)) { + processedUserText = new PromptTemplate(processedUserText, userParams).render(); + } + messages.add(new UserMessage(processedUserText, this.media())); + } + + if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!this.functionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(this.functionNames())); + } + if (!this.functionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(this.functionCallbacks()); + } } + return new Prompt(messages, this.chatOptions()); } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java new file mode 100644 index 00000000000..dc7a5cdb554 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java @@ -0,0 +1,73 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.chat.client.advisor.api; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.util.Assert; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public record AdvisedResponse(ChatResponse response, Map adviseContext) { + + public AdvisedResponse contextTransform(Function, Map> contextTransform) { + return new AdvisedResponse(response, + Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(adviseContext)))); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ChatResponse response; + + private Map adviseContext; + + public Builder() { + } + + public static Builder from(AdvisedResponse advisedResponse) { + return new Builder().withResponse(advisedResponse.response) + .withAdviseContext(advisedResponse.adviseContext); + } + + public Builder withResponse(ChatResponse response) { + Assert.notNull(response, "the response must be non-null"); + this.response = response; + return this; + } + + public Builder withAdviseContext(Map adviseContext) { + Assert.notNull(adviseContext, "the adviseContext must be non-null"); + this.adviseContext = adviseContext; + return this; + } + + public AdvisedResponse build() { + Assert.notNull(this.adviseContext, "the adviseContext must be non-null"); + return new AdvisedResponse(response, adviseContext); + } + + } +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java index eaa09c6efb5..c66521cbc3d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/Advisor.java @@ -20,8 +20,8 @@ * * @author Christian Tzolov * @since 1.0.0 - * @see RequestAdvisor - * @see ResponseAdvisor + * @see BeforeAdvisor + * @see AfterAdvisor * @see CallAroundAdvisor * @see StreamAroundAdvisor * @see AroundAdvisorChain diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AfterAdvisor.java similarity index 74% rename from spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AfterAdvisor.java index e1040c91da0..0e8fe3e0204 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/ResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AfterAdvisor.java @@ -16,8 +16,6 @@ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -33,23 +31,20 @@ * @author Christian Tzolov * @since 1.0.0 */ -public interface ResponseAdvisor extends Advisor { +public interface AfterAdvisor extends Advisor { /** * @param response the {@link ChatResponse} data to be advised. Represents the row * {@link ChatResponse} data after the {@link ChatModel#call(Prompt)} method is * called. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. * @return the advised {@link ChatResponse}. */ - ChatResponse adviseResponse(ChatResponse response, Map adviseContext); + AdvisedResponse afterCall(AdvisedResponse advisedResponse); /** * Different modes of advising the streaming responses. */ - public enum StreamResponseMode { + public enum AfterStreamMode { /** * Called for each response element in the Flux. The response advisor can modify @@ -77,21 +72,18 @@ public enum StreamResponseMode { } - default StreamResponseMode getStreamResponseMode() { - return StreamResponseMode.ON_FINISH_ELEMENT; + default AfterStreamMode getAfterStreamMode() { + return AfterStreamMode.ON_FINISH_ELEMENT; } /** - * @param fluxResponse the streaming {@link ChatResponse} data to be advised. + * @param advisedResponseStream the streaming {@link ChatResponse} data to be advised. * Represents the row {@link ChatResponse} stream data after the * {@link ChatModel#stream(Prompt)} method is called. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. * @return the advised {@link ChatResponse} flux. */ - default Flux adviseResponse(Flux fluxResponse, Map adviseContext) { - return fluxResponse; + default Flux afterStream(Flux advisedResponseStream) { + return advisedResponseStream; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java index f33a686c24d..ce44782f33b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AroundAdvisorChain.java @@ -1,16 +1,11 @@ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; - import reactor.core.publisher.Flux; public interface AroundAdvisorChain { - ChatResponse nextAroundCall(AdvisedRequest advisedRequest, Map adviceContext); + AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); - Flux nextAroundStream(AdvisedRequest advisedRequest, Map adviceContext); + Flux nextAroundStream(AdvisedRequest advisedRequest); } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BeforeAdvisor.java similarity index 77% rename from spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java rename to spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BeforeAdvisor.java index 8ba198e323a..cd0f09abb86 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/RequestAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/BeforeAdvisor.java @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; @@ -31,16 +28,13 @@ * @author Christian Tzolov * @since 1.0.0 */ -public interface RequestAdvisor extends Advisor { +public interface BeforeAdvisor extends Advisor { /** * @param request the {@link AdvisedRequest} data to be advised. Represents the row * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. - * @param adviseContext the shared data between the advisors in the chain. It is - * shared between all request and response advising points of all advisors in the - * chain. * @return the advised {@link AdvisedRequest}. */ - AdvisedRequest adviseRequest(AdvisedRequest request, Map adviseContext); + AdvisedRequest before(AdvisedRequest request); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java index 4ff26544e94..18fadf671ef 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java @@ -15,10 +15,8 @@ */ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; /** * @author Christian Tzolov @@ -30,10 +28,9 @@ public interface CallAroundAdvisor extends Advisor { /** * Around advice that wraps the ChatModel#call(Prompt) method. * @param advisedRequest the advised request - * @param adviceContext the advice context * @param chain the advisor chain * @return the response */ - ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, AroundAdvisorChain chain); + AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain); } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java index d2379a7724a..1fbc1f23648 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java @@ -15,11 +15,6 @@ */ package org.springframework.ai.chat.client.advisor.api; -import java.util.Map; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.model.ChatResponse; - import reactor.core.publisher.Flux; /** @@ -31,11 +26,9 @@ public interface StreamAroundAdvisor extends Advisor { /** * Around advice that wraps the invocation of the advised request. * @param advisedRequest - * @param adviceContext * @param chain the chain of advisors to execute * @return the result of the advised request */ - Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain); + Flux aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain); } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java deleted file mode 100644 index 51dcc6b5ec9..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/CacheAroundAdvisor.java +++ /dev/null @@ -1,149 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ -package org.springframework.ai.chat.client.advisor.around; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import org.springframework.ai.chat.client.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.metadata.ChatGenerationMetadata; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.model.MessageAggregator; -import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.util.CollectionUtils; - -import reactor.core.publisher.Flux; - -/** - * @author Christian Tzolov - * @since 1.0.0 - */ -public class CacheAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { - - private final VectorStore vectorStore; - - private static final String DOCUMENT_METADATA_ADVISOR_CACHE_TAG = "advisorCacheDocument"; - - private static final String DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE = "advisorCacheResponse"; - - public CacheAroundAdvisor(VectorStore vectorStore) { - this.vectorStore = vectorStore; - } - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); - if (cachedResponseOption.isPresent()) { - return cachedResponseOption.get(); - } - - ChatResponse chatResponse = chain.nextAroundCall(advisedRequest, adviceContext); - - saveCacheEntry(advisedRequest.userText(), chatResponse); - - return chatResponse; - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { - - var cachedResponseOption = getCacheEntry(advisedRequest, adviceContext); - if (cachedResponseOption.isPresent()) { - return Flux.just(cachedResponseOption.get()); - } - - Flux fluxChatResponse = chain.nextAroundStream(advisedRequest, adviceContext); - - return new MessageAggregator().aggregate(fluxChatResponse, chatResponse -> { - saveCacheEntry(advisedRequest.userText(), chatResponse); - }); - } - - private void saveCacheEntry(String userQuestion, ChatResponse chatResponse) { - List assistantMessages = chatResponse.getResults().stream().map(g -> (Message) g.getOutput()).toList(); - if (!CollectionUtils.isEmpty(assistantMessages)) { - this.vectorStore.add(toDocuments(userQuestion, assistantMessages)); - } - } - - private Optional getCacheEntry(AdvisedRequest advisedRequest, Map adviceContext) { - - // TODO: convert into pompty first or at least materialize the user params. - String userText = advisedRequest.userText(); - - // @formatter:off - var searchRequest = SearchRequest.query(userText) - .withSimilarityThreshold(0.95) - .withTopK(1) - .withFilterExpression("'"+ DOCUMENT_METADATA_ADVISOR_CACHE_TAG + "' == 'true'"); - // @formatter:on - - List doc = vectorStore.similaritySearch(searchRequest); - - // return cached response - return CollectionUtils.isEmpty(doc) ? Optional.empty() : Optional.of(fromDocument(doc.get(0))); - } - - private ChatResponse fromDocument(Document doc) { - - if (!doc.getMetadata().containsKey(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE)) { - throw new IllegalStateException("The document is missing the cache response metadata!"); - } - String cachedResponse = "" + doc.getMetadata().get(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE); - - return ChatResponse.builder() - .withGenerations(List.of(new Generation(new AssistantMessage(cachedResponse, Map.of()), - ChatGenerationMetadata.from("STOP", null)))) - .build(); - } - - private List toDocuments(String userQuestion, List messages) { - - List docs = messages.stream() - .filter(m -> m.getMessageType() == MessageType.ASSISTANT) - .map(message -> { - var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>()); - metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_TAG, "true"); - metadata.put(DOCUMENT_METADATA_ADVISOR_CACHE_RESPONSE, message.getContent()); - // TODO: Pehaps we need to serialize the message metadata to the document - - return new Document(userQuestion, metadata); - - }) - .toList(); - - return docs; - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java index f682ceb1b70..a354fda1d8f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/around/SafeGuardAroundAdvisor.java @@ -1,9 +1,9 @@ package org.springframework.ai.chat.client.advisor.around; import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; @@ -15,7 +15,7 @@ /** * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the * response if the user input contains any of the sensitive words. - * + * * @author Christian Tzolov * @since 1.0.0 */ @@ -32,27 +32,26 @@ public String getName() { } @Override - public ChatResponse aroundCall(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { - return ChatResponse.builder().withGenerations(List.of()).build(); + return new AdvisedResponse(ChatResponse.builder().withGenerations(List.of()).build(), + advisedRequest.adviseContext()); } - return chain.nextAroundCall(advisedRequest, adviceContext); + return chain.nextAroundCall(advisedRequest); } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, Map adviceContext, - AroundAdvisorChain chain) { + public Flux aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) && sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { return Flux.empty(); } - return chain.nextAroundStream(advisedRequest, adviceContext); + return chain.nextAroundStream(advisedRequest); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java index e267a5ee2fe..c803017d07a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservableHelper.java @@ -18,13 +18,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.RequestAdvisor; -import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor; -import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; import org.springframework.util.CollectionUtils; import io.micrometer.observation.Observation; @@ -37,57 +36,62 @@ public abstract class AdvisorObservableHelper { public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention(); - public static AdvisedRequest adviseRequest(Observation parentObservation, RequestAdvisor advisor, - AdvisedRequest advisedRequest, Map advisorContext) { + public static AdvisedRequest adviseRequest(Observation parentObservation, BeforeAdvisor advisor, + AdvisedRequest advisedRequest) { var observationContext = AdvisorObservationContext.builder() .withAdvisorName(advisor.getName()) .withAdvisorType(AdvisorObservationContext.Type.BEFORE) .withAdvisedRequest(advisedRequest) - .withAdvisorRequestContext(advisorContext) + .withAdvisorRequestContext(advisedRequest.adviseContext()) .build(); return AdvisorObservationDocumentation.AI_ADVISOR .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, parentObservation.getObservationRegistry()) .parentObservation(parentObservation) - .observe(() -> advisor.adviseRequest(advisedRequest, advisorContext)); + .observe(() -> advisor.before(advisedRequest)); } - public static ChatResponse adviseResponse(Observation parentObservation, ResponseAdvisor advisor, - ChatResponse response, Map advisorContext) { + public static AdvisedResponse adviseResponse(Observation parentObservation, AfterAdvisor advisor, + AdvisedResponse advisedResponse) { var observationContext = AdvisorObservationContext.builder() .withAdvisorName(advisor.getName()) .withAdvisorType(AdvisorObservationContext.Type.AFTER) - .withAdvisorRequestContext(advisorContext) + .withAdvisorRequestContext(advisedResponse.adviseContext()) .build(); return AdvisorObservationDocumentation.AI_ADVISOR .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, parentObservation.getObservationRegistry()) .parentObservation(parentObservation) - .observe(() -> advisor.adviseResponse(response, advisorContext)); + .observe(() -> advisor.afterCall(advisedResponse)); } - public static List extractRequestAdvisors(List advisors) { + public static List requestAdvisors(List advisors) { + if (CollectionUtils.isEmpty(advisors)) { + return Collections.emptyList(); + } return advisors.stream() - .filter(advisor -> advisor instanceof RequestAdvisor) - .map(a -> (RequestAdvisor) a) + .filter(advisor -> advisor instanceof BeforeAdvisor) + .map(a -> (BeforeAdvisor) a) .toList(); } /** - * Extracts the {@link ResponseAdvisor} instances from the given list of advisors and + * Extracts the {@link AfterAdvisor} instances from the given list of advisors and * returns them in reverse order. * @param advisors list of all registered advisor types. - * @return the list of {@link ResponseAdvisor} instances in reverse order. + * @return the list of {@link AfterAdvisor} instances in reverse order. */ - public static List extractResponseAdvisors(List advisors) { - + public static List responseAdvisors(List advisors) { + if (CollectionUtils.isEmpty(advisors)) { + return Collections.emptyList(); + } var list = advisors.stream() - .filter(advisor -> advisor instanceof ResponseAdvisor) - .map(a -> (ResponseAdvisor) a) + .filter(advisor -> advisor instanceof AfterAdvisor) + .map(a -> (AfterAdvisor) a) .toList(); // reverse the list diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index 5a4421e98ca..a28906995ea 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -17,8 +17,8 @@ import java.util.Map; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.util.Assert; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index 6aef10ed777..c7da921726a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java index 432029fb813..a1c402eeb4f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientAdvisorTests.java @@ -29,6 +29,7 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.messages.Message; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java new file mode 100644 index 00000000000..34cec55a9fa --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; +import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.AfterAdvisor; +import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.BeforeAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; + +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +public class AdvisorsTests { + + @Mock + ChatModel chatModel; + + @Captor + ArgumentCaptor promptCaptor; + + public static class MockBeforeAdvisor implements BeforeAdvisor { + + public AdvisedRequest advisedRequest; + + @Override + public String getName() { + return "MockBeforeAdvisor"; + } + + @Override + public AdvisedRequest before(AdvisedRequest advisedRequest) { + this.advisedRequest = advisedRequest.augmentContext(context -> { + context.put("before", "BEFORE"); + return context; + }); + + return this.advisedRequest; + } + + } + + public static class MockAfterAdvisor implements AfterAdvisor { + + public AdvisedResponse advisedResponse; + + public Flux advisedResponseStream; + + @Override + public String getName() { + return "MockAfterAdvisor"; + } + + @Override + public AdvisedResponse afterCall(AdvisedResponse advisedResponse) { + this.advisedResponse = advisedResponse.contextTransform(context -> { + context.put("afterCall", "AFTER_CALL"); + return context; + }); + return this.advisedResponse; + } + + @Override + public Flux afterStream(Flux advisedResponseStream) { + + this.advisedResponseStream = advisedResponseStream.map(advisedResponse -> { + return advisedResponse.contextTransform(context -> { + context.put("afterStream", "AFTER_STREAM"); + return context; + }); + }); + + return this.advisedResponseStream; + } + + } + + public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + public AdvisedRequest advisedRequest; + + public AdvisedResponse advisedResponse; + + public Flux advisedResponseStream; + + @Override + public String getName() { + return "MockAroundAdvisor"; + } + + @Override + public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.augmentContext(context -> { + context.put("aroundCallBefore", "AROUND_CALL_BEFORE"); + return context; + }); + + AdvisedResponse advisedResponse = this.advisedResponse = chain.nextAroundCall(this.advisedRequest); + + this.advisedResponse = advisedResponse.contextTransform(context -> { + context.put("aroundCallAfter", "AROUND_CALL_AFTER"); + return context; + }); + + return this.advisedResponse; + } + + @Override + public Flux aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain) { + + this.advisedRequest = advisedRequest.augmentContext(context -> { + context.put("aroundStreamBefore", "AROUND_STREAM_BEFORE"); + return context; + }); + + Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); + + this.advisedResponseStream = advisedResponseStream.map(advisedResponse -> { + return advisedResponse.contextTransform(context -> { + context.put("aroundStreamAfter", "AROUND_STREAM_AFTER"); + return context; + }); + }); + + return this.advisedResponseStream; + } + + } + + @Test + public void callAdvisorsContextPropagation() { + + var mockBeforeAdvisor = new MockBeforeAdvisor(); + var mockAfterAdvisor = new MockAfterAdvisor(); + + var mockAroundAdvisor = new MockAroundAdvisor(); + + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))))); + + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))))); + + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello John"))))) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("Your name is John"))))); + + var chatClient = ChatClient.builder(chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(mockBeforeAdvisor, mockAroundAdvisor, mockAfterAdvisor) + .build(); + + var content = chatClient.prompt() + .user("my name is John") + .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) + .call() + .content(); + + assertThat(content).isEqualTo("Hello John"); + + // BEFORE + assertThat(mockBeforeAdvisor.advisedRequest.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .doesNotContainKeys("aroundCallBefore", "aroundCallAfter", "afterCall"); + + assertThat(mockBeforeAdvisor.advisedRequest.advisorParams()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .doesNotContainKey("before"); + + // AROUND + assertThat(mockAroundAdvisor.advisedResponse.response()).isNotNull(); + assertThat(mockAroundAdvisor.advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .doesNotContainKeys("afterCall"); + + // AFTER + assertThat(mockAfterAdvisor.advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .containsEntry("afterCall", "AFTER_CALL"); + } + + @Test + public void streamAdvisorsContextPropagation() { + + var mockBeforeAdvisor = new MockBeforeAdvisor(); + var mockAfterAdvisor = new MockAfterAdvisor(); + + var mockAroundAdvisor = new MockAroundAdvisor(); + + when(chatModel.stream(promptCaptor.capture())) + .thenReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); + + when(chatModel.stream(promptCaptor.capture())) + .thenReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("Hello")))), + new ChatResponse(List.of(new Generation(new AssistantMessage(" John")))))); + + var chatClient = ChatClient.builder(chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(mockBeforeAdvisor, mockAroundAdvisor, mockAfterAdvisor) + .build(); + + var content = chatClient.prompt() + .user("my name is John") + .advisors(a -> a.param("key1", "value1").params(Map.of("key2", "value2"))) + .stream() + .content() + .collectList() + .block() + .stream() + .collect(Collectors.joining()); + + assertThat(content).isEqualTo("Hello John"); + + // BEFORE + assertThat(mockBeforeAdvisor.advisedRequest.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .doesNotContainKeys("aroundCallBefore", "aroundCallAfter", "afterCall"); + + assertThat(mockBeforeAdvisor.advisedRequest.advisorParams()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .doesNotContainKey("before"); + + // AROUND + assertThat(mockAroundAdvisor.advisedResponseStream).isNotNull(); + + mockAroundAdvisor.advisedResponseStream.collectList().block().forEach(advisedResponse -> { + assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .doesNotContainKeys("afterCall"); + + assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .doesNotContainKeys("afterCall"); + }); + + assertThat(mockAroundAdvisor.advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .doesNotContainKeys("afterCall"); + + // AFTER + assertThat(mockAfterAdvisor.advisedResponse.adviseContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("before", "BEFORE") + .containsEntry("aroundCallBefore", "AROUND_CALL_BEFORE") + .containsEntry("aroundCallAfter", "AROUND_CALL_AFTER") + .containsEntry("afterCall", "AFTER_CALL"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index 7067139d6a3..3bec3dea7c1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -25,8 +25,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.RequestResponseAdvisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;