diff --git a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java index d6ca554e61e..3d651fe3271 100644 --- a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java +++ b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisor.java @@ -21,17 +21,13 @@ import java.util.Map; import java.util.stream.Collectors; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisorChain; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; +import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponseStreamUtils; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; @@ -53,7 +49,7 @@ * @author Thomas Vitale * @since 1.0.0 */ -public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public class QuestionAnswerAdvisor implements BaseAdvisor { public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents"; @@ -80,100 +76,24 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv private final SearchRequest searchRequest; - private final boolean protectFromBlocking; + private final Scheduler scheduler; private final int order; - /** - * The QuestionAnswerAdvisor retrieves context information from a Vector Store and - * combines it with the user's text. - * @param vectorStore The vector store to use - */ public QuestionAnswerAdvisor(VectorStore vectorStore) { - this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER); - } - - /** - * The QuestionAnswerAdvisor retrieves context information from a Vector Store and - * combines it with the user's text. - * @param vectorStore The vector store to use - * @param searchRequest The search request defined using the portable filter - * expression syntax - * @deprecated in favor of the builder: {@link #builder(VectorStore)} - */ - @Deprecated - public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) { - this(vectorStore, searchRequest, DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER); - } - - /** - * The QuestionAnswerAdvisor retrieves context information from a Vector Store and - * combines it with the user's text. - * @param vectorStore The vector store to use - * @param searchRequest The search request defined using the portable filter - * expression syntax - * @param userTextAdvise The user text to append to the existing user prompt. The text - * should contain a placeholder named "question_answer_context". - * @deprecated in favor of the builder: {@link #builder(VectorStore)} - */ - @Deprecated - public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) { - this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), true, + this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, BaseAdvisor.DEFAULT_SCHEDULER, DEFAULT_ORDER); } - /** - * The QuestionAnswerAdvisor retrieves context information from a Vector Store and - * combines it with the user's text. - * @param vectorStore The vector store to use - * @param searchRequest The search request defined using the portable filter - * expression syntax - * @param userTextAdvise The user text to append to the existing user prompt. The text - * should contain a placeholder named "question_answer_context". - * @param protectFromBlocking If true the advisor will protect the execution from - * blocking threads. If false the advisor will not protect the execution from blocking - * threads. This is useful when the advisor is used in a non-blocking environment. It - * is true by default. - * @deprecated in favor of the builder: {@link #builder(VectorStore)} - */ - @Deprecated - public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, - boolean protectFromBlocking) { - this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking, - DEFAULT_ORDER); - } - - /** - * The QuestionAnswerAdvisor retrieves context information from a Vector Store and - * combines it with the user's text. - * @param vectorStore The vector store to use - * @param searchRequest The search request defined using the portable filter - * expression syntax - * @param userTextAdvise The user text to append to the existing user prompt. The text - * should contain a placeholder named "question_answer_context". - * @param protectFromBlocking If true the advisor will protect the execution from - * blocking threads. If false the advisor will not protect the execution from blocking - * threads. This is useful when the advisor is used in a non-blocking environment. It - * is true by default. - * @param order The order of the advisor. - * @deprecated in favor of the builder: {@link #builder(VectorStore)} - */ - @Deprecated - public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, - boolean protectFromBlocking, int order) { - this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking, - order); - } - QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, @Nullable PromptTemplate promptTemplate, - boolean protectFromBlocking, int order) { + @Nullable Scheduler scheduler, int order) { Assert.notNull(vectorStore, "vectorStore cannot be null"); Assert.notNull(searchRequest, "searchRequest cannot be null"); this.vectorStore = vectorStore; this.searchRequest = searchRequest; this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; - this.protectFromBlocking = protectFromBlocking; + this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER; this.order = order; } @@ -181,97 +101,71 @@ public static Builder builder(VectorStore vectorStore) { return new Builder(vectorStore); } - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - @Override public int getOrder() { return this.order; } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - AdvisedRequest advisedRequest2 = before(advisedRequest); - - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest2); - - return after(advisedResponse); - } - - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - // This can be executed by both blocking and non-blocking Threads - // E.g. a command line or Tomcat blocking Thread implementation - // or by a WebFlux dispatch in a non-blocking manner. - Flux advisedResponses = (this.protectFromBlocking) ? - // @formatter:off - Mono.just(advisedRequest) - .publishOn(Schedulers.boundedElastic()) - .map(this::before) - .flatMapMany(request -> chain.nextAroundStream(request)) - : chain.nextAroundStream(before(advisedRequest)); - // @formatter:on - - return advisedResponses.map(ar -> { - if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) { - ar = after(ar); - } - return ar; - }); - } - - private AdvisedRequest before(AdvisedRequest request) { - - var context = new HashMap<>(request.adviseContext()); - + public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { // 1. Search for similar documents in the vector store. var searchRequestToUse = SearchRequest.from(this.searchRequest) - .query(request.userText()) - .filterExpression(doGetFilterExpression(context)) + .query(chatClientRequest.prompt().getUserMessage().getText()) + .filterExpression(doGetFilterExpression(chatClientRequest.context())) .build(); List documents = this.vectorStore.similaritySearch(searchRequestToUse); // 2. Create the context from the documents. + Map context = new HashMap<>(chatClientRequest.context()); context.put(RETRIEVED_DOCUMENTS, documents); - String documentContext = documents.stream() - .map(Document::getText) - .collect(Collectors.joining(System.lineSeparator())); + String documentContext = documents == null ? "" + : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator())); // 3. Augment the user prompt with the document context. String augmentedUserText = this.promptTemplate.mutate() - .template(request.userText() + System.lineSeparator() + this.promptTemplate.getTemplate()) + .template(chatClientRequest.prompt().getUserMessage().getText() + System.lineSeparator() + + this.promptTemplate.getTemplate()) .variables(Map.of("question_answer_context", documentContext)) .build() .render(); - AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .userText(augmentedUserText) - .adviseContext(context) + // 4. Update ChatClientRequest with augmented prompt. + return chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().augmentUserMessage(augmentedUserText)) + .context(context) .build(); - - return advisedRequest; } - private AdvisedResponse after(AdvisedResponse advisedResponse) { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); - chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS)); - return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); + @Override + public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { + ChatResponse.Builder chatResponseBuilder; + if (chatClientResponse.chatResponse() == null) { + chatResponseBuilder = ChatResponse.builder(); + } + else { + chatResponseBuilder = ChatResponse.builder().from(chatClientResponse.chatResponse()); + } + chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS)); + return ChatClientResponse.builder() + .chatResponse(chatResponseBuilder.build()) + .context(chatClientResponse.context()) + .build(); } + @Nullable protected Filter.Expression doGetFilterExpression(Map context) { - if (!context.containsKey(FILTER_EXPRESSION) || !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) { return this.searchRequest.getFilterExpression(); } return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString()); + } + @Override + public Scheduler getScheduler() { + return this.scheduler; } public static final class Builder { @@ -282,7 +176,7 @@ public static final class Builder { private PromptTemplate promptTemplate; - private boolean protectFromBlocking = true; + private Scheduler scheduler; private int order = DEFAULT_ORDER; @@ -303,18 +197,13 @@ public Builder searchRequest(SearchRequest searchRequest) { return this; } - /** - * @deprecated in favour of {@link #promptTemplate(PromptTemplate)} - */ - @Deprecated - public Builder userTextAdvise(String userTextAdvise) { - Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!"); - this.promptTemplate = PromptTemplate.builder().template(userTextAdvise).build(); + public Builder protectFromBlocking(boolean protectFromBlocking) { + this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate(); return this; } - public Builder protectFromBlocking(boolean protectFromBlocking) { - this.protectFromBlocking = protectFromBlocking; + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; return this; } @@ -324,8 +213,8 @@ public Builder order(int order) { } public QuestionAnswerAdvisor build() { - return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate, - this.protectFromBlocking, this.order); + return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate, this.scheduler, + this.order); } } diff --git a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java index 76c7ad36883..0866c5bd7bd 100644 --- a/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java +++ b/advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.chat.client.advisor.vectorstore; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -24,19 +25,20 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.util.StringUtils; /** * Memory is retrieved from a VectorStore added into the prompt's system text. @@ -87,80 +89,76 @@ public static Builder builder(VectorStore chatMemory) { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + chatClientRequest = this.before(chatClientRequest); - advisedRequest = this.before(advisedRequest); + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + this.after(chatClientResponse); - this.observeAfter(advisedResponse); - - return advisedResponse; + return chatClientResponse; } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, - this::before); + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + Flux chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest, + streamAdvisorChain, this::before); - // The observeAfter will certainly be executed on non-blocking Threads in case - // of some models - e.g. when the model client is a WebClient - return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after); } - private AdvisedRequest before(AdvisedRequest request) { - - String advisedSystemText; - if (StringUtils.hasText(request.systemText())) { - advisedSystemText = request.systemText() + System.lineSeparator() + this.systemTextAdvise; - } - else { - advisedSystemText = this.systemTextAdvise; - } + private ChatClientRequest before(ChatClientRequest chatClientRequest) { + String conversationId = this.doGetConversationId(chatClientRequest.context()); + int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context()); + // 1. Retrieve the chat memory for the current conversation. var searchRequest = SearchRequest.builder() - .query(request.userText()) - .topK(this.doGetChatMemoryRetrieveSize(request.adviseContext())) - .filterExpression( - DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(request.adviseContext()) + "'") + .query(chatClientRequest.prompt().getUserMessage().getText()) + .topK(chatMemoryRetrieveSize) + .filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'") .build(); List documents = this.getChatMemoryStore().similaritySearch(searchRequest); - String longTermMemory = documents.stream() - .map(Document::getText) - .collect(Collectors.joining(System.lineSeparator())); - - Map advisedSystemParams = new HashMap<>(request.systemParams()); - advisedSystemParams.put("long_term_memory", longTermMemory); - - AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .systemText(advisedSystemText) - .systemParams(advisedSystemParams) + // 2. Processed memory messages as a string. + String longTermMemory = documents == null ? "" + : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator())); + + // 2. Augment the system message. + SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); + String augmentedSystemText = PromptTemplate.builder() + .template(systemMessage.getText() + System.lineSeparator() + this.systemTextAdvise) + .variables(Map.of("long_term_memory", longTermMemory)) + .build() + .render(); + + // 3. Create a new request with the augmented system message. + ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)) .build(); - UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build(); - this.getChatMemoryStore() - .write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext()))); + // 4. Add the new user message to the conversation memory. + UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); + this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId)); - return advisedRequest; + return processedChatClientRequest; } - private void observeAfter(AdvisedResponse advisedResponse) { - - List assistantMessages = advisedResponse.response() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - + private void after(ChatClientResponse chatClientResponse) { + List assistantMessages = new ArrayList<>(); + if (chatClientResponse.chatResponse() != null) { + assistantMessages = chatClientResponse.chatResponse() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); + } this.getChatMemoryStore() - .write(toDocuments(assistantMessages, this.doGetConversationId(advisedResponse.adviseContext()))); + .write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context()))); } private List toDocuments(List messages, String conversationId) { - List docs = messages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) .map(message -> { diff --git a/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisorTests.java b/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisorTests.java index 79f7a9a1425..8fb33377428 100644 --- a/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisorTests.java +++ b/advisors/spring-ai-advisors-vector-store/src/test/java/org/springframework/ai/chat/client/advisor/vectorstore/QuestionAnswerAdvisorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,7 @@ * @author Christian Tzolov * @author Timo Salm * @author Alexandros Pappas + * @author Thomas Vitale */ @ExtendWith(MockitoExtension.class) public class QuestionAnswerAdvisorTests { @@ -112,8 +113,9 @@ public Duration getTokensReset() { given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture())) .willReturn(List.of(new Document("doc1"), new Document("doc2"))); - var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, - SearchRequest.builder().similarityThreshold(0.99d).topK(6).build()); + var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore) + .searchRequest(SearchRequest.builder().similarityThreshold(0.99d).topK(6).build()) + .build(); var chatClient = ChatClient.builder(this.chatModel) .defaultSystem("Default system text.") @@ -187,7 +189,9 @@ public void qaAdvisorTakesUserTextParametersIntoAccountForSimilaritySearch() { .willReturn(List.of(new Document("doc1"), new Document("doc2"))); var chatClient = ChatClient.builder(this.chatModel).build(); - var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build()); + var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore) + .searchRequest(SearchRequest.builder().build()) + .build(); var userTextTemplate = "Please answer my question {question}"; // @formatter:off @@ -215,10 +219,15 @@ public void qaAdvisorTakesUserParameterizedUserMessagesIntoAccountForSimilarityS .willReturn(List.of(new Document("doc1"), new Document("doc2"))); var chatClient = ChatClient.builder(this.chatModel).build(); - var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build()); + var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore) + .searchRequest(SearchRequest.builder().build()) + .build(); var userTextTemplate = "Please answer my question {question}"; - var userPromptTemplate = PromptTemplate.builder().template(userTextTemplate).variables(Map.of("question", "XYZ")).build(); + var userPromptTemplate = PromptTemplate.builder() + .template(userTextTemplate) + .variables(Map.of("question", "XYZ")) + .build(); var userMessage = userPromptTemplate.createMessage(); // @formatter:off chatClient.prompt(new Prompt(userMessage)) diff --git a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientAutoConfiguration.java b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientAutoConfiguration.java index 6f288088087..2068f27222c 100644 --- a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientAutoConfiguration.java +++ b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientAutoConfiguration.java @@ -22,7 +22,6 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClientCustomizer; -import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationFilter; import org.springframework.ai.chat.model.ChatModel; @@ -80,20 +79,6 @@ ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuild return chatClientBuilderConfigurer.configure(builder); } - /** - * @deprecated in favour of {@link #chatClientPromptContentObservationFilter()}. - */ - @Bean - @ConditionalOnMissingBean - @ConditionalOnProperty(prefix = ChatClientBuilderProperties.CONFIG_PREFIX + ".observations", name = "include-input", - havingValue = "true") - @Deprecated - ChatClientInputContentObservationFilter chatClientInputContentObservationFilter() { - logger.warn( - "You have enabled the inclusion of the input content in the observations, with the risk of exposing sensitive or private information. Please, be careful!"); - return new ChatClientInputContentObservationFilter(); - } - @Bean @ConditionalOnMissingBean @ConditionalOnProperty(prefix = ChatClientBuilderProperties.CONFIG_PREFIX + ".observations", diff --git a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientBuilderProperties.java b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientBuilderProperties.java index 84c2869ae2f..38f9bb39ed1 100644 --- a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientBuilderProperties.java +++ b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/main/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientBuilderProperties.java @@ -17,7 +17,6 @@ package org.springframework.ai.model.chat.client.autoconfigure; import org.springframework.boot.context.properties.ConfigurationProperties; -import org.springframework.boot.context.properties.DeprecatedConfigurationProperty; /** * Configuration properties for the chat client builder. @@ -55,27 +54,11 @@ public void setEnabled(boolean enabled) { public static class Observations { - /** - * Whether to include the input content in the observations. - * @deprecated Use {@link #includePrompt} instead. - */ - @Deprecated - private boolean includeInput = false; - /** * Whether to include the prompt content in the observations. */ private boolean includePrompt = false; - @DeprecatedConfigurationProperty(replacement = "spring.ai.chat.observations.include-prompt") - public boolean isIncludeInput() { - return this.includeInput; - } - - public void setIncludeInput(boolean includeCompletion) { - this.includeInput = includeCompletion; - } - public boolean isIncludePrompt() { return this.includePrompt; } diff --git a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/test/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientObservationAutoConfigurationTests.java b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/test/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientObservationAutoConfigurationTests.java index 9e646cdac51..8ab54ff0a76 100644 --- a/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/test/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientObservationAutoConfigurationTests.java +++ b/auto-configurations/models/chat/client/spring-ai-autoconfigure-model-chat-client/src/test/java/org/springframework/ai/model/chat/client/autoconfigure/ChatClientObservationAutoConfigurationTests.java @@ -18,7 +18,6 @@ import org.junit.jupiter.api.Test; -import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter; import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationFilter; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -36,18 +35,6 @@ class ChatClientObservationAutoConfigurationTests { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(ChatClientAutoConfiguration.class)); - @Test - void inputContentFilterDefault() { - this.contextRunner - .run(context -> assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class)); - } - - @Test - void inputContentFilterEnabled() { - this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true") - .run(context -> assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class)); - } - @Test void promptContentFilterDefault() { this.contextRunner diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index b80f5186dab..b3b3a876192 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,13 +28,10 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; @@ -64,6 +61,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @SpringBootTest @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") @@ -81,7 +79,7 @@ public class OpenAiPaymentTransactionIT { @ValueSource(strings = { "paymentStatus", "paymentStatuses" }) public void transactionPaymentStatuses(String functionName) { List content = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .toolNames(functionName) .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -112,7 +110,7 @@ public void streamingPaymentStatuses(String functionName) { }); Flux flux = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .toolNames(functionName) .user(u -> u.text(""" What is the status of my payment transactions 001, 002 and 003? @@ -141,49 +139,6 @@ record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - - advisedRequest = this.before(advisedRequest); - - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); - - this.observeAfter(advisedResponse); - - return advisedResponse; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.toolNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - record Transaction(String id) { } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java index 84832be8b68..cdd92d07b04 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/ReReadingAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,13 @@ package org.springframework.ai.openai.chat.client; -import java.util.HashMap; import java.util.Map; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.AdvisorChain; +import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; +import org.springframework.ai.chat.prompt.PromptTemplate; /** * Drawing inspiration from the human strategy of re-reading, this advisor implements a @@ -36,9 +32,10 @@ * Language Models * * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ -public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public class ReReadingAdvisor implements BaseAdvisor { private static final String DEFAULT_RE2_ADVISE_TEMPLATE = """ {re2_input_query} @@ -57,29 +54,22 @@ public ReReadingAdvisor(String re2AdviseTemplate) { this.re2AdviseTemplate = re2AdviseTemplate; } - public String getName() { - return this.getClass().getSimpleName(); - } - - private AdvisedRequest before(AdvisedRequest advisedRequest) { - - Map advisedUserParams = new HashMap<>(advisedRequest.userParams()); - advisedUserParams.put("re2_input_query", advisedRequest.userText()); - - return AdvisedRequest.from(advisedRequest) - .userText(this.re2AdviseTemplate) - .userParams(advisedUserParams) - .build(); - } - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - return chain.nextAroundCall(this.before(advisedRequest)); + public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { + String augmentedUserText = PromptTemplate.builder() + .template(re2AdviseTemplate) + .variables(Map.of("re2_input_query", chatClientRequest.prompt().getUserMessage().getText())) + .build() + .render(); + + return chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().augmentUserMessage(augmentedUserText)) + .build(); } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - return chain.nextAroundStream(this.before(advisedRequest)); + public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { + return chatClientResponse; } @Override diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java index 9ac5776bc87..b5901439bb0 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,13 +29,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; @@ -57,6 +54,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @@ -75,7 +73,7 @@ public class VertexAiGeminiPaymentTransactionIT { public void paymentStatuses() { // @formatter:off String content = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .toolNames("paymentStatus") .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -92,7 +90,7 @@ public void paymentStatuses() { public void streamingPaymentStatuses() { Flux streamContent = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .toolNames("paymentStatus") .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -120,45 +118,6 @@ record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.toolNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - record Transaction(String id) { } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java index b5c3d2e07dc..75296acf525 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionMethodIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,13 +29,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; @@ -59,6 +56,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @@ -76,10 +74,15 @@ public class VertexAiGeminiPaymentTransactionMethodIT { @Test public void paymentStatuses() { - String content = this.chatClient.prompt().advisors(new LoggingAdvisor()).toolNames("paymentStatus").user(""" - What is the status of my payment transactions 001, 002 and 003? - If requred invoke the function per transaction. - """).call().content(); + String content = this.chatClient.prompt() + .advisors(new SimpleLoggerAdvisor()) + .toolNames("paymentStatus") + .user(""" + What is the status of my payment transactions 001, 002 and 003? + If requred invoke the function per transaction. + """) + .call() + .content(); logger.info("" + content); assertThat(content).contains("001", "002", "003"); @@ -90,7 +93,7 @@ public void paymentStatuses() { public void streamingPaymentStatuses() { Flux streamContent = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .toolNames("paymentStatus") .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -118,45 +121,6 @@ record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.toolNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - record Transaction(String id) { } diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java index 9fc92285318..ba905754299 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionToolsIT.java @@ -28,13 +28,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.annotation.Tool; @@ -56,6 +53,7 @@ /** * @author Christian Tzolov + * @author Thomas Vitale */ @SpringBootTest @EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") @@ -74,7 +72,7 @@ public class VertexAiGeminiPaymentTransactionToolsIT { public void paymentStatuses() { // @formatter:off String content = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .tools(new MyTools()) .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -91,7 +89,7 @@ public void paymentStatuses() { public void streamingPaymentStatuses() { Flux streamContent = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) + .advisors(new SimpleLoggerAdvisor()) .tools(new MyTools()) .user(""" What is the status of my payment transactions 001, 002 and 003? @@ -119,45 +117,6 @@ record TransactionStatusResponse(String id, String status) { } - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.toolNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - record Transaction(String id) { } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java index 02b26768317..b7a967c3a3b 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java @@ -26,15 +26,7 @@ public enum ChatClientAttributes { //@formatter:off - @Deprecated // Only for backward compatibility until the next release. - ADVISORS("spring.ai.chat.client.advisors"), - @Deprecated // Only for backward compatibility until the next release. - CHAT_MODEL("spring.ai.chat.client.model"), - OUTPUT_FORMAT("spring.ai.chat.client.output.format"), - @Deprecated // Only for backward compatibility until the next release. - USER_PARAMS("spring.ai.chat.client.user.params"), - @Deprecated // Only for backward compatibility until the next release. - SYSTEM_PARAMS("spring.ai.chat.client.system.params"); + OUTPUT_FORMAT("spring.ai.chat.client.output.format"); //@formatter:on diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index d5a5f2de6fb..a83898c5779 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -22,8 +22,8 @@ import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -38,7 +38,6 @@ import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor; import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor; import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationContext; @@ -47,16 +46,18 @@ import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention; import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.content.Media; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.template.st.StTemplateRenderer; import org.springframework.ai.tool.ToolCallback; @@ -64,7 +65,6 @@ import org.springframework.ai.tool.ToolCallbacks; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; -import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -97,52 +97,6 @@ public DefaultChatClient(DefaultChatClientRequestSpec defaultChatClientRequest) this.defaultChatClientRequest = defaultChatClientRequest; } - private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) { - Assert.notNull(inputRequest, "inputRequest cannot be null"); - - Map advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams()); - - // Process userText, media and messages before creating the AdvisedRequest. - String userText = inputRequest.userText; - List media = inputRequest.media; - List messages = inputRequest.messages; - - // If the userText is empty, then try extracting the userText from the last - // message - // in the messages list and remove it from the messages list. - if (!StringUtils.hasText(userText) && !CollectionUtils.isEmpty(messages)) { - Message lastMessage = messages.get(messages.size() - 1); - if (lastMessage.getMessageType() == MessageType.USER) { - UserMessage userMessage = (UserMessage) lastMessage; - if (StringUtils.hasText(userMessage.getText())) { - userText = lastMessage.getText(); - } - Collection messageMedia = userMessage.getMedia(); - if (!CollectionUtils.isEmpty(messageMedia)) { - media.addAll(messageMedia); - } - messages = messages.subList(0, messages.size() - 1); - } - } - - return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions, - media, inputRequest.toolNames, inputRequest.toolCallbacks, messages, inputRequest.userParams, - inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext, - inputRequest.toolContext); - } - - @Deprecated - public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, - ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { - - return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), - advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), - advisedRequest.toolCallbacks(), advisedRequest.messages(), advisedRequest.toolNames(), - advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), - advisedRequest.advisorParams(), observationRegistry, customObservationConvention, - advisedRequest.toolContext(), null); - } - @Override public ChatClientRequestSpec prompt() { return new DefaultChatClientRequestSpec(this.defaultChatClientRequest); @@ -510,7 +464,7 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c .request(chatClientRequest) .advisors(advisorChain.getCallAdvisors()) .stream(false) - .withFormat(outputFormat) + .format(outputFormat) .build(); var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(observationConvention, @@ -522,19 +476,6 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build(); } - @NonNull - private static ChatClientRequest augmentPromptWithFormatInstructions(ChatClientRequest chatClientRequest, - String outputFormat) { - Prompt augmentedPrompt = chatClientRequest.prompt() - .augmentUserMessage(userMessage -> userMessage.mutate() - .text(userMessage.getText() + System.lineSeparator() + outputFormat) - .build()); - return ChatClientRequest.builder() - .prompt(augmentedPrompt) - .context(Map.copyOf(chatClientRequest.context())) - .build(); - } - @Nullable private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) { return Optional.ofNullable(chatResponse) @@ -709,14 +650,6 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER; } - private ObservationRegistry getObservationRegistry() { - return this.observationRegistry; - } - - private ChatClientObservationConvention getCustomObservationConvention() { - return this.observationConvention; - } - @Nullable public String getUserText() { return this.userText; @@ -768,6 +701,10 @@ public Map getToolContext() { return this.toolContext; } + public TemplateRenderer getTemplateRenderer() { + return this.templateRenderer; + } + /** * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. @@ -775,6 +712,9 @@ public Map getToolContext() { public Builder mutate() { DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient .builder(this.chatModel, this.observationRegistry, this.observationConvention) + .defaultTemplateRenderer(this.templateRenderer) + .defaultToolCallbacks(this.toolCallbacks) + .defaultToolContext(this.toolContext) .defaultToolNames(StringUtils.toStringArray(this.toolNames)); if (StringUtils.hasText(this.userText)) { @@ -791,8 +731,6 @@ public Builder mutate() { } builder.addMessages(this.messages); - builder.addToolCallbacks(this.toolCallbacks); - builder.addToolContext(this.toolContext); return builder; } @@ -968,14 +906,14 @@ public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer) public CallResponseSpec call() { BaseAdvisorChain advisorChain = buildAdvisorChain(); - return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(this.templateRenderer), - advisorChain, observationRegistry, observationConvention); + return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, + observationRegistry, observationConvention); } public StreamResponseSpec stream() { BaseAdvisorChain advisorChain = buildAdvisorChain(); - return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(this.templateRenderer), - advisorChain, observationRegistry, observationConvention); + return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain, + observationRegistry, observationConvention); } private BaseAdvisorChain buildAdvisorChain() { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 35e011af497..8d314b0ef59 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -36,7 +36,6 @@ import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; -import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.core.io.Resource; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -180,13 +179,6 @@ public Builder defaultToolCallbacks(ToolCallbackProvider... toolCallbackProvider return this; } - @Deprecated // Use defaultTools() - public Builder defaultFunction(String name, String description, java.util.function.Function function) { - this.defaultRequest - .toolCallbacks(FunctionToolCallback.builder(name, function).description(description).build()); - return this; - } - public Builder defaultToolContext(Map toolContext) { this.defaultRequest.toolContext(toolContext); return this; @@ -202,13 +194,4 @@ void addMessages(List messages) { this.defaultRequest.messages(messages); } - void addToolCallbacks(List toolCallbacks) { - Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - this.defaultRequest.toolCallbacks(toolCallbacks.toArray(ToolCallback[]::new)); - } - - void addToolContext(Map toolContext) { - this.defaultRequest.toolContext(toolContext); - } - } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java new file mode 100644 index 00000000000..3b793d6a99a --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientUtils.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Utilities for supporting the {@link DefaultChatClient} implementation. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +class DefaultChatClientUtils { + + static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) { + Assert.notNull(inputRequest, "inputRequest cannot be null"); + + /* + * ==========* MESSAGES * ========== + */ + + List processedMessages = new ArrayList<>(); + + // System Text => First in the list + String processedSystemText = inputRequest.getSystemText(); + if (StringUtils.hasText(processedSystemText)) { + if (!CollectionUtils.isEmpty(inputRequest.getSystemParams())) { + processedSystemText = PromptTemplate.builder() + .template(processedSystemText) + .variables(inputRequest.getSystemParams()) + .renderer(inputRequest.getTemplateRenderer()) + .build() + .render(); + } + processedMessages.add(new SystemMessage(processedSystemText)); + } + + // Messages => In the middle of the list + if (!CollectionUtils.isEmpty(inputRequest.getMessages())) { + processedMessages.addAll(inputRequest.getMessages()); + } + + // User Test => Last in the list + String processedUserText = inputRequest.getUserText(); + if (StringUtils.hasText(processedUserText)) { + if (!CollectionUtils.isEmpty(inputRequest.getUserParams())) { + processedUserText = PromptTemplate.builder() + .template(processedUserText) + .variables(inputRequest.getUserParams()) + .renderer(inputRequest.getTemplateRenderer()) + .build() + .render(); + } + processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build()); + } + + /* + * ==========* OPTIONS * ========== + */ + + ChatOptions processedChatOptions = inputRequest.getChatOptions(); + if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) { + if (!inputRequest.getToolNames().isEmpty()) { + Set toolNames = ToolCallingChatOptions + .mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames()); + toolCallingChatOptions.setToolNames(toolNames); + } + if (!inputRequest.getToolCallbacks().isEmpty()) { + List toolCallbacks = ToolCallingChatOptions + .mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks()); + ToolCallingChatOptions.validateToolCallbacks(toolCallbacks); + toolCallingChatOptions.setToolCallbacks(toolCallbacks); + } + if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) { + Map toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(), + toolCallingChatOptions.getToolContext()); + toolCallingChatOptions.setToolContext(toolContext); + } + } + + /* + * ==========* REQUEST * ========== + */ + + return ChatClientRequest.builder() + .prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build()) + .context(new ConcurrentHashMap<>(inputRequest.getAdvisorParams())) + .build(); + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 3d778937647..214ec9671c9 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,17 +19,17 @@ import java.util.Map; import java.util.function.Function; -import org.springframework.ai.chat.memory.ChatMemory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.util.Assert; /** @@ -38,9 +38,10 @@ * @param the type of the chat memory. * @author Christian Tzolov * @author Ilayaperumal Gopinathan + * @author Thomas Vitale * @since 1.0.0 */ -public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public abstract class AbstractChatMemoryAdvisor implements CallAdvisor, StreamAdvisor { /** * The key to retrieve the chat memory conversation id from the context. @@ -176,26 +177,18 @@ protected int doGetChatMemoryRetrieveSize(Map context) { : this.defaultChatMemoryRetrieveSize; } - /** - * Execute the next advisor in the chain. - * @param advisedRequest the advised request - * @param chain the advisor chain - * @param beforeAdvise the before advise function - * @return the advised response - */ - protected Flux doNextWithProtectFromBlockingBefore(AdvisedRequest advisedRequest, - StreamAroundAdvisorChain chain, Function beforeAdvise) { - + protected Flux doNextWithProtectFromBlockingBefore(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain, Function before) { // This can be executed by both blocking and non-blocking Threads // E.g. a command line or Tomcat blocking Thread implementation // or by a WebFlux dispatch in a non-blocking manner. return (this.protectFromBlocking) ? // @formatter:off - Mono.just(advisedRequest) - .publishOn(Schedulers.boundedElastic()) - .map(beforeAdvise) - .flatMapMany(request -> chain.nextAroundStream(request)) - : chain.nextAroundStream(beforeAdvise.apply(advisedRequest)); + Mono.just(chatClientRequest) + .publishOn(Schedulers.boundedElastic()) + .map(before) + .flatMapMany(streamAdvisorChain::nextStream) + : streamAdvisorChain.nextStream(before.apply(chatClientRequest)); } /** @@ -242,7 +235,7 @@ protected AbstractBuilder(T chatMemory) { * @param conversationId the conversation id * @return the builder */ - public AbstractBuilder conversationId(String conversationId) { + public AbstractBuilder conversationId(String conversationId) { this.conversationId = conversationId; return this; } @@ -252,7 +245,7 @@ public AbstractBuilder conversationId(String conversationId) { * @param chatMemoryRetrieveSize the chat memory retrieve size * @return the builder */ - public AbstractBuilder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) { + public AbstractBuilder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) { this.chatMemoryRetrieveSize = chatMemoryRetrieveSize; return this; } @@ -262,7 +255,7 @@ public AbstractBuilder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) { * @param protectFromBlocking whether to protect from blocking * @return the builder */ - public AbstractBuilder protectFromBlocking(boolean protectFromBlocking) { + public AbstractBuilder protectFromBlocking(boolean protectFromBlocking) { this.protectFromBlocking = protectFromBlocking; return this; } @@ -272,7 +265,7 @@ public AbstractBuilder protectFromBlocking(boolean protectFromBlocking) { * @param order the order * @return the builder */ - public AbstractBuilder order(int order) { + public AbstractBuilder order(int order) { this.order = order; return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java index d2dd9cf4d62..390208c675f 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java @@ -20,7 +20,7 @@ import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.CallAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; @@ -46,7 +46,7 @@ private ChatModelCallAdvisor(ChatModel chatModel) { } @Override - public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java index f9b04c23f5f..9bec78fc4d9 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelStreamAdvisor.java @@ -43,7 +43,8 @@ private ChatModelStreamAdvisor(ChatModel chatModel) { } @Override - public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) { + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); return chatModel.stream(chatClientRequest.prompt()) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index dd6d0d3da8f..838b8cbc5b6 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -26,13 +26,9 @@ import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.template.st.StTemplateRenderer; import org.springframework.lang.Nullable; @@ -62,51 +58,46 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain { private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build(); - private final List originalCallAdvisors; + private final List originalCallAdvisors; - private final List originalStreamAdvisors; + private final List originalStreamAdvisors; - private final Deque callAroundAdvisors; + private final Deque callAdvisors; - private final Deque streamAroundAdvisors; + private final Deque streamAdvisors; private final ObservationRegistry observationRegistry; private final TemplateRenderer templateRenderer; DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, @Nullable TemplateRenderer templateRenderer, - Deque callAroundAdvisors, Deque streamAroundAdvisors) { + Deque callAdvisors, Deque streamAdvisors) { Assert.notNull(observationRegistry, "the observationRegistry must be non-null"); - Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null"); - Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null"); + Assert.notNull(callAdvisors, "the callAdvisors must be non-null"); + Assert.notNull(streamAdvisors, "the streamAdvisors must be non-null"); this.observationRegistry = observationRegistry; this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER; - this.callAroundAdvisors = callAroundAdvisors; - this.streamAroundAdvisors = streamAroundAdvisors; - this.originalCallAdvisors = List.copyOf(callAroundAdvisors); - this.originalStreamAdvisors = List.copyOf(streamAroundAdvisors); + this.callAdvisors = callAdvisors; + this.streamAdvisors = streamAdvisors; + this.originalCallAdvisors = List.copyOf(callAdvisors); + this.originalStreamAdvisors = List.copyOf(streamAdvisors); } public static Builder builder(ObservationRegistry observationRegistry) { return new Builder(observationRegistry); } - @Override - public TemplateRenderer getTemplateRenderer() { - return this.templateRenderer; - } - @Override public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); - if (this.callAroundAdvisors.isEmpty()) { + if (this.callAdvisors.isEmpty()) { throw new IllegalStateException("No CallAdvisors available to execute"); } - var advisor = this.callAroundAdvisors.pop(); + var advisor = this.callAdvisors.pop(); var observationContext = AdvisorObservationContext.builder() .advisorName(advisor.getName()) @@ -116,52 +107,7 @@ public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) { return AdvisorObservationDocumentation.AI_ADVISOR .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> { - // Supports both deprecated and new API. - if (advisor instanceof CallAdvisor callAdvisor) { - return callAdvisor.adviseCall(chatClientRequest, this); - } - AdvisedResponse advisedResponse = advisor.aroundCall(AdvisedRequest.from(chatClientRequest), this); - ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); - observationContext.setChatClientResponse(chatClientResponse); - return chatClientResponse; - }); - } - - /** - * @deprecated Use {@link #nextCall(ChatClientRequest)} instead - */ - @Override - @Deprecated - public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { - Assert.notNull(advisedRequest, "the advisedRequest cannot be null"); - - if (this.callAroundAdvisors.isEmpty()) { - throw new IllegalStateException("No AroundAdvisor available to execute"); - } - - var advisor = this.callAroundAdvisors.pop(); - - var observationContext = AdvisorObservationContext.builder() - .advisorName(advisor.getName()) - .chatClientRequest(advisedRequest.toChatClientRequest(templateRenderer)) - .order(advisor.getOrder()) - .build(); - - return AdvisorObservationDocumentation.AI_ADVISOR - .observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) - .observe(() -> { - // Supports both deprecated and new API. - if (advisor instanceof CallAdvisor callAdvisor) { - ChatClientResponse chatClientResponse = callAdvisor - .adviseCall(advisedRequest.toChatClientRequest(templateRenderer), this); - return AdvisedResponse.from(chatClientResponse); - } - AdvisedResponse advisedResponse = advisor.aroundCall(advisedRequest, this); - ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); - observationContext.setChatClientResponse(chatClientResponse); - return advisedResponse; - }); + .observe(() -> advisor.adviseCall(chatClientRequest, this)); } @Override @@ -169,11 +115,11 @@ public Flux nextStream(ChatClientRequest chatClientRequest) Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); return Flux.deferContextual(contextView -> { - if (this.streamAroundAdvisors.isEmpty()) { + if (this.streamAdvisors.isEmpty()) { return Flux.error(new IllegalStateException("No StreamAdvisors available to execute")); } - var advisor = this.streamAroundAdvisors.pop(); + var advisor = this.streamAdvisors.pop(); AdvisorObservationContext observationContext = AdvisorObservationContext.builder() .advisorName(advisor.getName()) @@ -187,77 +133,21 @@ public Flux nextStream(ChatClientRequest chatClientRequest) observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); // @formatter:off - return Flux.defer(() -> { - // Supports both deprecated and new API. - if (advisor instanceof StreamAdvisor streamAdvisor) { - return streamAdvisor.adviseStream(chatClientRequest, this) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - } - return advisor.aroundStream(AdvisedRequest.from(chatClientRequest), this) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)) - .map(AdvisedResponse::toChatClientResponse); - }); - // @formatter:on - }); - } - - /** - * @deprecated Use {@link #nextStream(ChatClientRequest)} instead. - */ - @Override - @Deprecated - public Flux nextAroundStream(AdvisedRequest advisedRequest) { - Assert.notNull(advisedRequest, "the advisedRequest cannot be null"); - - return Flux.deferContextual(contextView -> { - if (this.streamAroundAdvisors.isEmpty()) { - return Flux.error(new IllegalStateException("No AroundAdvisor available to execute")); - } - - var advisor = this.streamAroundAdvisors.pop(); - - AdvisorObservationContext observationContext = AdvisorObservationContext.builder() - .advisorName(advisor.getName()) - .chatClientRequest(advisedRequest.toChatClientRequest(templateRenderer)) - .order(advisor.getOrder()) - .build(); - - var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null, - DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry); - - observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - - // @formatter:off - return Flux.defer(() -> { - // Supports both deprecated and new API. - if (advisor instanceof StreamAdvisor streamAdvisor) { - return streamAdvisor.adviseStream(advisedRequest.toChatClientRequest(templateRenderer), this) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)) - .map(AdvisedResponse::from); - } - - return advisor.aroundStream(advisedRequest, this) + return Flux.defer(() -> advisor.adviseStream(chatClientRequest, this) .doOnError(observation::error) .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - }); + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))); // @formatter:on }); } @Override - public List getCallAdvisors() { + public List getCallAdvisors() { return this.originalCallAdvisors; } @Override - public List getStreamAdvisors() { + public List getStreamAdvisors() { return this.originalStreamAdvisors; } @@ -270,16 +160,16 @@ public static class Builder { private final ObservationRegistry observationRegistry; - private final Deque callAroundAdvisors; + private final Deque callAdvisors; - private final Deque streamAroundAdvisors; + private final Deque streamAdvisors; private TemplateRenderer templateRenderer; public Builder(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; - this.callAroundAdvisors = new ConcurrentLinkedDeque<>(); - this.streamAroundAdvisors = new ConcurrentLinkedDeque<>(); + this.callAdvisors = new ConcurrentLinkedDeque<>(); + this.streamAdvisors = new ConcurrentLinkedDeque<>(); } public Builder templateRenderer(TemplateRenderer templateRenderer) { @@ -296,22 +186,22 @@ public Builder pushAll(List advisors) { Assert.notNull(advisors, "the advisors must be non-null"); Assert.noNullElements(advisors, "the advisors must not contain null elements"); if (!CollectionUtils.isEmpty(advisors)) { - List callAroundAdvisorList = advisors.stream() - .filter(a -> a instanceof CallAroundAdvisor) - .map(a -> (CallAroundAdvisor) a) + List callAroundAdvisorList = advisors.stream() + .filter(a -> a instanceof CallAdvisor) + .map(a -> (CallAdvisor) a) .toList(); if (!CollectionUtils.isEmpty(callAroundAdvisorList)) { - callAroundAdvisorList.forEach(this.callAroundAdvisors::push); + callAroundAdvisorList.forEach(this.callAdvisors::push); } - List streamAroundAdvisorList = advisors.stream() - .filter(a -> a instanceof StreamAroundAdvisor) - .map(a -> (StreamAroundAdvisor) a) + List streamAroundAdvisorList = advisors.stream() + .filter(a -> a instanceof StreamAdvisor) + .map(a -> (StreamAdvisor) a) .toList(); if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) { - streamAroundAdvisorList.forEach(this.streamAroundAdvisors::push); + streamAroundAdvisorList.forEach(this.streamAdvisors::push); } this.reOrder(); @@ -323,20 +213,20 @@ public Builder pushAll(List advisors) { * (Re)orders the advisors in priority order based on their Ordered attribute. */ private void reOrder() { - ArrayList callAdvisors = new ArrayList<>(this.callAroundAdvisors); + ArrayList callAdvisors = new ArrayList<>(this.callAdvisors); OrderComparator.sort(callAdvisors); - this.callAroundAdvisors.clear(); - callAdvisors.forEach(this.callAroundAdvisors::addLast); + this.callAdvisors.clear(); + callAdvisors.forEach(this.callAdvisors::addLast); - ArrayList streamAdvisors = new ArrayList<>(this.streamAroundAdvisors); + ArrayList streamAdvisors = new ArrayList<>(this.streamAdvisors); OrderComparator.sort(streamAdvisors); - this.streamAroundAdvisors.clear(); - streamAdvisors.forEach(this.streamAroundAdvisors::addLast); + this.streamAdvisors.clear(); + streamAdvisors.forEach(this.streamAdvisors::addLast); } public DefaultAroundAdvisorChain build() { - return new DefaultAroundAdvisorChain(this.observationRegistry, this.templateRenderer, - this.callAroundAdvisors, this.streamAroundAdvisors); + return new DefaultAroundAdvisorChain(this.observationRegistry, this.templateRenderer, this.callAdvisors, + this.streamAdvisors); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 6a563e52025..6a6862a93c7 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,11 +21,11 @@ import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -57,58 +57,59 @@ public static Builder builder(ChatMemory chatMemory) { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + chatClientRequest = this.before(chatClientRequest); - advisedRequest = this.before(advisedRequest); + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + this.after(chatClientResponse); - this.observeAfter(advisedResponse); - - return advisedResponse; + return chatClientResponse; } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, - this::before); + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + Flux chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest, + streamAdvisorChain, this::before); - return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after); } - private AdvisedRequest before(AdvisedRequest request) { - - String conversationId = this.doGetConversationId(request.adviseContext()); + private ChatClientRequest before(ChatClientRequest chatClientRequest) { + String conversationId = this.doGetConversationId(chatClientRequest.context()); - int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(request.adviseContext()); + int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context()); // 1. Retrieve the chat memory for the current conversation. List memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize); // 2. Advise the request messages list. - List advisedMessages = new ArrayList<>(request.messages()); - advisedMessages.addAll(memoryMessages); + List processedMessages = new ArrayList<>(memoryMessages); + processedMessages.addAll(chatClientRequest.prompt().getInstructions()); // 3. Create a new request with the advised messages. - AdvisedRequest advisedRequest = AdvisedRequest.from(request).messages(advisedMessages).build(); + ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()) + .build(); - // 4. Add the new user input to the conversation memory. - UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build(); - this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); + // 4. Add the new user message to the conversation memory. + UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); + this.getChatMemoryStore().add(conversationId, userMessage); - return advisedRequest; + return processedChatClientRequest; } - private void observeAfter(AdvisedResponse advisedResponse) { - - List assistantMessages = advisedResponse.response() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); + private void after(ChatClientResponse chatClientResponse) { + List assistantMessages = new ArrayList<>(); + if (chatClientResponse.chatResponse() != null) { + assistantMessages = chatClientResponse.chatResponse() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); + } + this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages); } public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index e8d919dd794..aa2927523c3 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -16,19 +16,20 @@ package org.springframework.ai.chat.client.advisor; -import java.util.HashMap; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import org.springframework.util.StringUtils; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.Advisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -40,6 +41,7 @@ * * @author Christian Tzolov * @author Miloš Havránek + * @author Thomas Vitale * @since 1.0.0 */ public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor { @@ -83,68 +85,68 @@ public static Builder builder(ChatMemory chatMemory) { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + chatClientRequest = this.before(chatClientRequest); - advisedRequest = this.before(advisedRequest); + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + this.after(chatClientResponse); - this.observeAfter(advisedResponse); - - return advisedResponse; + return chatClientResponse; } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - Flux advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain, - this::before); + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + Flux chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest, + streamAdvisorChain, this::before); - return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); + return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after); } - private AdvisedRequest before(AdvisedRequest request) { + private ChatClientRequest before(ChatClientRequest chatClientRequest) { + String conversationId = this.doGetConversationId(chatClientRequest.context()); + int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context()); - // 1. Advise system parameters. - List memoryMessages = this.getChatMemoryStore() - .get(this.doGetConversationId(request.adviseContext()), - this.doGetChatMemoryRetrieveSize(request.adviseContext())); + // 1. Retrieve the chat memory for the current conversation. + List memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize); - String memory = (memoryMessages != null) ? memoryMessages.stream() + // 2. Processed memory messages as a string. + String memory = memoryMessages.stream() .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) .map(m -> m.getMessageType() + ":" + m.getText()) - .collect(Collectors.joining(System.lineSeparator())) : ""; - - Map advisedSystemParams = new HashMap<>(request.systemParams()); - advisedSystemParams.put("memory", memory); - - // 2. Advise the system text. - String systemText = request.systemText(); - String advisedSystemText = (StringUtils.hasText(systemText) ? systemText + System.lineSeparator() : "") - + this.systemTextAdvise; - - // 3. Create a new request with the advised system text and parameters. - AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .systemText(advisedSystemText) - .systemParams(advisedSystemParams) + .collect(Collectors.joining(System.lineSeparator())); + + // 2. Augment the system message. + SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage(); + String augmentedSystemText = PromptTemplate.builder() + .template(systemMessage.getText() + System.lineSeparator() + this.systemTextAdvise) + .variables(Map.of("memory", memory)) + .build() + .render(); + + // 3. Create a new request with the augmented system message. + ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() + .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText)) .build(); - // 4. Add the new user input to the conversation memory. - UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build(); - this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); + // 4. Add the new user message to the conversation memory. + UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); + this.getChatMemoryStore().add(conversationId, userMessage); - return advisedRequest; + return processedChatClientRequest; } - private void observeAfter(AdvisedResponse advisedResponse) { - - List assistantMessages = advisedResponse.response() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); - - this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); + private void after(ChatClientResponse chatClientResponse) { + List assistantMessages = new ArrayList<>(); + if (chatClientResponse.chatResponse() != null) { + assistantMessages = chatClientResponse.chatResponse() + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); + } + this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages); } public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index aa36979e377..a319f9dd4f3 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,15 +17,16 @@ package org.springframework.ai.chat.client.advisor; import java.util.List; +import java.util.Map; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -33,14 +34,15 @@ import org.springframework.util.CollectionUtils; /** - * A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the - * response if the user input contains any of the sensitive words. + * An advisor that blocks the call to the model provider if the user input contains any of + * the sensitive words. * * @author Christian Tzolov * @author Ilayaperumal Gopinathan + * @author Thomas Vitale * @since 1.0.0 */ -public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor { private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?"; @@ -73,32 +75,33 @@ public String getName() { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { - - return createFailureResponse(advisedRequest); + && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { + return createFailureResponse(chatClientRequest); } - return chain.nextAroundCall(advisedRequest); + return callAdvisorChain.nextCall(chatClientRequest); } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) { - return Flux.just(createFailureResponse(advisedRequest)); + && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { + return Flux.just(createFailureResponse(chatClientRequest)); } - return chain.nextAroundStream(advisedRequest); + return streamAdvisorChain.nextStream(chatClientRequest); } - private AdvisedResponse createFailureResponse(AdvisedRequest advisedRequest) { - return new AdvisedResponse(ChatResponse.builder() - .generations(List.of(new Generation(new AssistantMessage(this.failureResponse)))) - .build(), advisedRequest.adviseContext()); + private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) { + return ChatClientResponse.builder() + .chatResponse(ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage(this.failureResponse)))) + .build()) + .context(Map.copyOf(chatClientRequest.context())) + .build(); } @Override diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java index c6fc0d794b8..e5000a73a0f 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,39 +18,39 @@ import java.util.function.Function; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.MessageAggregator; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.lang.Nullable; /** * A simple logger advisor that logs the request and response messages. * * @author Christian Tzolov */ -public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public class SimpleLoggerAdvisor implements CallAdvisor, StreamAdvisor { - public static final Function DEFAULT_REQUEST_TO_STRING = request -> request.toString(); + public static final Function DEFAULT_REQUEST_TO_STRING = ChatClientRequest::toString; - public static final Function DEFAULT_RESPONSE_TO_STRING = response -> ModelOptionsUtils - .toJsonStringPrettyPrinter(response); + public static final Function DEFAULT_RESPONSE_TO_STRING = ModelOptionsUtils::toJsonStringPrettyPrinter; private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class); - private final Function requestToString; + private final Function requestToString; private final Function responseToString; - private int order; + private final int order; public SimpleLoggerAdvisor() { this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, 0); @@ -60,30 +60,50 @@ public SimpleLoggerAdvisor(int order) { this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, order); } - public SimpleLoggerAdvisor(Function requestToString, - Function responseToString, int order) { - this.requestToString = requestToString; - this.responseToString = responseToString; + public SimpleLoggerAdvisor(@Nullable Function requestToString, + @Nullable Function responseToString, int order) { + this.requestToString = requestToString != null ? requestToString : DEFAULT_REQUEST_TO_STRING; + this.responseToString = responseToString != null ? responseToString : DEFAULT_RESPONSE_TO_STRING; this.order = order; } @Override - public String getName() { - return this.getClass().getSimpleName(); + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + logRequest(chatClientRequest); + + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest); + + logResponse(chatClientResponse); + + return chatClientResponse; } @Override - public int getOrder() { - return this.order; + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + logRequest(chatClientRequest); + + Flux chatClientResponses = streamAdvisorChain.nextStream(chatClientRequest); + + return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse); } - private AdvisedRequest before(AdvisedRequest request) { + private void logRequest(ChatClientRequest request) { logger.debug("request: {}", this.requestToString.apply(request)); - return request; } - private void observeAfter(AdvisedResponse advisedResponse) { - logger.debug("response: {}", this.responseToString.apply(advisedResponse.response())); + private void logResponse(ChatClientResponse chatClientResponse) { + logger.debug("response: {}", this.responseToString.apply(chatClientResponse.chatResponse())); + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; } @Override @@ -91,26 +111,40 @@ public String toString() { return SimpleLoggerAdvisor.class.getSimpleName(); } - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public static Builder builder() { + return new Builder(); + } - advisedRequest = before(advisedRequest); + public static class Builder { - AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest); + private Function requestToString; - observeAfter(advisedResponse); + private Function responseToString; - return advisedResponse; - } + private int order = 0; - @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { + private Builder() { + } + + public Builder requestToString(Function requestToString) { + this.requestToString = requestToString; + return this; + } + + public Builder responseToString(Function responseToString) { + this.responseToString = responseToString; + return this; + } - advisedRequest = before(advisedRequest); + public Builder order(int order) { + this.order = order; + return this; + } - Flux advisedResponses = chain.nextAroundStream(advisedRequest); + public SimpleLoggerAdvisor build() { + return new SimpleLoggerAdvisor(requestToString, responseToString, order); + } - return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java deleted file mode 100644 index 4d8352eda96..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ /dev/null @@ -1,457 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.function.Function; - -import org.springframework.ai.chat.client.ChatClientAttributes; -import org.springframework.ai.chat.client.ChatClientRequest; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.content.Media; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.template.TemplateRenderer; -import org.springframework.ai.template.st.StTemplateRenderer; -import org.springframework.ai.tool.ToolCallback; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - -/** - * The data of the chat client request that can be modified before the execution of the - * ChatClient's call method - * - * @param chatModel the chat model used - * @param userText the text provided by the user - * @param systemText the text provided by the system - * @param chatOptions the options for the chat - * @param media the list of media items - * @param toolNames the list of function names - * @param toolCallbacks the list of function callbacks - * @param messages the list of messages - * @param userParams the map of user parameters - * @param systemParams the map of system parameters - * @param advisors the list of request response advisors - * @param advisorParams the map of advisor parameters - * @param adviseContext the map of advise context - * @param toolContext the tool context - * @author Christian Tzolov - * @author Thomas Vitale - * @author Ilayaperumal Gopinathan - * @deprecated Use {@link ChatClientRequest} instead. - * @since 1.0.0 - */ -public record AdvisedRequest( -// @formatter:off - ChatModel chatModel, - String userText, - @Nullable - String systemText, - @Nullable - ChatOptions chatOptions, - List media, - List toolNames, - List toolCallbacks, - List messages, - Map userParams, - Map systemParams, - List advisors, - @Deprecated // Not really used. Use "adviseContext" instead. - Map advisorParams, - Map adviseContext, - Map toolContext -// @formatter:on -) { - - public AdvisedRequest { - Assert.notNull(chatModel, "chatModel cannot be null"); - Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages), - "userText cannot be null or empty unless messages are provided and contain Tool Response message."); - Assert.notNull(media, "media cannot be null"); - Assert.noNullElements(media, "media cannot contain null elements"); - Assert.notNull(toolNames, "toolNames cannot be null"); - Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); - Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - Assert.notNull(messages, "messages cannot be null"); - Assert.noNullElements(messages, "messages cannot contain null elements"); - Assert.notNull(userParams, "userParams cannot be null"); - Assert.noNullElements(userParams.keySet(), "userParams keys cannot contain null elements"); - Assert.noNullElements(userParams.values(), "userParams values cannot contain null elements"); - Assert.notNull(systemParams, "systemParams cannot be null"); - Assert.noNullElements(systemParams.keySet(), "systemParams keys cannot contain null elements"); - Assert.noNullElements(systemParams.values(), "systemParams values cannot contain null elements"); - Assert.notNull(advisors, "advisors cannot be null"); - Assert.noNullElements(advisors, "advisors cannot contain null elements"); - Assert.notNull(advisorParams, "advisorParams cannot be null"); - Assert.noNullElements(advisorParams.keySet(), "advisorParams keys cannot contain null elements"); - Assert.noNullElements(advisorParams.values(), "advisorParams values cannot contain null elements"); - Assert.notNull(adviseContext, "adviseContext cannot be null"); - Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot contain null elements"); - Assert.noNullElements(adviseContext.values(), "adviseContext values cannot contain null elements"); - Assert.notNull(toolContext, "toolContext cannot be null"); - Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements"); - Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements"); - } - - public static Builder builder() { - return new Builder(); - } - - public static Builder from(AdvisedRequest from) { - Assert.notNull(from, "AdvisedRequest cannot be null"); - - Builder builder = new Builder(); - builder.chatModel = from.chatModel; - builder.userText = from.userText; - builder.systemText = from.systemText; - builder.chatOptions = from.chatOptions; - builder.media = from.media; - builder.toolNames = from.toolNames; - builder.toolCallbacks = from.toolCallbacks; - builder.messages = from.messages; - builder.userParams = from.userParams; - builder.systemParams = from.systemParams; - builder.advisors = from.advisors; - builder.advisorParams = from.advisorParams; - builder.adviseContext = from.adviseContext; - builder.toolContext = from.toolContext; - return builder; - } - - @SuppressWarnings("unchecked") - public static AdvisedRequest from(ChatClientRequest from) { - Assert.notNull(from, "ChatClientRequest cannot be null"); - - List messages = new LinkedList<>(from.prompt().getInstructions()); - - Builder builder = new Builder(); - if (from.context().get(ChatClientAttributes.CHAT_MODEL.getKey()) instanceof ChatModel chatModel) { - builder.chatModel = chatModel; - } - - if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof UserMessage userMessage) { - builder.userText = userMessage.getText(); - builder.media = userMessage.getMedia(); - messages.remove(messages.size() - 1); - } - if (from.context().get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map contextUserParams) { - builder.userParams = (Map) contextUserParams; - } - - if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof SystemMessage systemMessage) { - builder.systemText = systemMessage.getText(); - messages.remove(messages.size() - 1); - } - if (from.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map contextSystemParams) { - builder.systemParams = (Map) contextSystemParams; - } - - builder.messages = messages; - - builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build()); - if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) { - builder.toolNames = options.getToolNames().stream().toList(); - builder.toolCallbacks = options.getToolCallbacks(); - builder.toolContext = options.getToolContext(); - } - - if (from.context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List advisors) { - builder.advisors = (List) advisors; - } - builder.advisorParams = Map.of(); - builder.adviseContext = from.context(); - - return builder.build(); - } - - public AdvisedRequest updateContext(Function, Map> contextTransform) { - Assert.notNull(contextTransform, "contextTransform cannot be null"); - return from(this) - .adviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))) - .build(); - } - - public ChatClientRequest toChatClientRequest() { - return toChatClientRequest(StTemplateRenderer.builder().build()); - } - - public ChatClientRequest toChatClientRequest(TemplateRenderer templateRenderer) { - return ChatClientRequest.builder() - .prompt(toPrompt(templateRenderer)) - .context(this.adviseContext) - .context(ChatClientAttributes.ADVISORS.getKey(), this.advisors) - .context(ChatClientAttributes.CHAT_MODEL.getKey(), this.chatModel) - .context(ChatClientAttributes.USER_PARAMS.getKey(), this.userParams) - .context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), this.systemParams) - .build(); - } - - public Prompt toPrompt() { - return toPrompt(StTemplateRenderer.builder().build()); - } - - public Prompt toPrompt(TemplateRenderer templateRenderer) { - var messages = new ArrayList<>(this.messages()); - - String processedSystemText = this.systemText(); - if (StringUtils.hasText(processedSystemText)) { - if (!CollectionUtils.isEmpty(this.systemParams())) { - processedSystemText = PromptTemplate.builder() - .template(processedSystemText) - .variables(this.systemParams()) - .renderer(templateRenderer) - .build() - .render(); - } - messages.add(new SystemMessage(processedSystemText)); - } - - if (StringUtils.hasText(this.userText())) { - Map userParams = new HashMap<>(this.userParams()); - String processedUserText = this.userText(); - if (!CollectionUtils.isEmpty(userParams)) { - processedUserText = PromptTemplate.builder() - .template(processedUserText) - .variables(userParams) - .renderer(templateRenderer) - .build() - .render(); - } - messages.add(UserMessage.builder().text(processedUserText).media(this.media()).build()); - } - - if (this.chatOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { - if (!this.toolNames().isEmpty()) { - toolCallingChatOptions.setToolNames(new HashSet<>(this.toolNames())); - } - if (!this.toolCallbacks().isEmpty()) { - toolCallingChatOptions.setToolCallbacks(this.toolCallbacks()); - } - if (!CollectionUtils.isEmpty(this.toolContext())) { - toolCallingChatOptions.setToolContext(this.toolContext()); - } - } - - return new Prompt(messages, this.chatOptions()); - } - - /** - * Builder for {@link AdvisedRequest}. - */ - public static final class Builder { - - private ChatModel chatModel; - - private String userText; - - private String systemText; - - private ChatOptions chatOptions; - - private List media = List.of(); - - private List toolNames = List.of(); - - private List toolCallbacks = List.of(); - - private List messages = List.of(); - - private Map userParams = Map.of(); - - private Map systemParams = Map.of(); - - private List advisors = List.of(); - - private Map advisorParams = Map.of(); - - private Map adviseContext = Map.of(); - - public Map toolContext = Map.of(); - - private Builder() { - } - - /** - * Set the chat model. - * @param chatModel the chat model - * @return this {@link Builder} instance - */ - public Builder chatModel(ChatModel chatModel) { - this.chatModel = chatModel; - return this; - } - - /** - * Set the user text. - * @param userText the user text - * @return this {@link Builder} instance - */ - public Builder userText(String userText) { - this.userText = userText; - return this; - } - - /** - * Set the system text. - * @param systemText the system text - * @return this {@link Builder} instance - */ - public Builder systemText(String systemText) { - this.systemText = systemText; - return this; - } - - /** - * Set the chat options. - * @param chatOptions the chat options - * @return this {@link Builder} instance - */ - public Builder chatOptions(ChatOptions chatOptions) { - this.chatOptions = chatOptions; - return this; - } - - /** - * Set the media. - * @param media the media - * @return this {@link Builder} instance - */ - public Builder media(List media) { - this.media = media; - return this; - } - - /** - * Set the tool names. - * @param toolNames the function names - * @return this {@link Builder} instance - */ - public Builder toolNames(List toolNames) { - this.toolNames = toolNames; - return this; - } - - /** - * Set the tool callbacks. - * @param toolCallbacks the tool callbacks - * @return this {@link Builder} instance - */ - public Builder functionCallbacks(List toolCallbacks) { - this.toolCallbacks = toolCallbacks; - return this; - } - - /** - * Set the messages. - * @param messages the messages - * @return this {@link Builder} instance - */ - public Builder messages(List messages) { - this.messages = messages; - return this; - } - - /** - * Set the user params. - * @param userParams the user params - * @return this {@link Builder} instance - */ - public Builder userParams(Map userParams) { - this.userParams = userParams; - return this; - } - - /** - * Set the system params. - * @param systemParams the system params - * @return this {@link Builder} instance - */ - public Builder systemParams(Map systemParams) { - this.systemParams = systemParams; - return this; - } - - /** - * Set the advisors. - * @param advisors the advisors - * @return this {@link Builder} instance - */ - public Builder advisors(List advisors) { - this.advisors = advisors; - return this; - } - - /** - * Set the advisor params. - * @param advisorParams the advisor params - * @return this {@link Builder} instance - * @deprecated in favor of {@link #adviseContext(Map)} - */ - @Deprecated - public Builder advisorParams(Map advisorParams) { - this.advisorParams = advisorParams; - return this; - } - - /** - * Set the advise context. - * @param adviseContext the advise context - * @return this {@link Builder} instance - */ - public Builder adviseContext(Map adviseContext) { - this.adviseContext = adviseContext; - return this; - } - - /** - * Set the tool context. - * @param toolContext the tool context - * @return this {@link Builder} instance - */ - public Builder toolContext(Map toolContext) { - this.toolContext = toolContext; - return this; - } - - /** - * Build the {@link AdvisedRequest} instance. - * @return a new {@link AdvisedRequest} instance - */ - public AdvisedRequest build() { - return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media, - this.toolNames, this.toolCallbacks, this.messages, this.userParams, this.systemParams, - this.advisors, this.advisorParams, this.adviseContext, this.toolContext); - } - - } - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java deleted file mode 100644 index 04644c7db8a..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponse.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import org.springframework.ai.chat.client.ChatClientResponse; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.lang.Nullable; -import org.springframework.util.Assert; - -/** - * The data of the chat client response that can be modified before the call returns. - * - * @param response the chat response - * @param adviseContext the context to advise the response - * @author Christian Tzolov - * @author Thomas Vitale - * @author Ilayaperumal Gopinathan - * @deprecated Use {@link ChatClientResponse} instead. - * @since 1.0.0 - */ -@Deprecated -public record AdvisedResponse(@Nullable ChatResponse response, Map adviseContext) { - - /** - * Create a new {@link AdvisedResponse} instance. - * @param response the chat response - * @param adviseContext the context to advise the response - */ - public AdvisedResponse { - Assert.notNull(adviseContext, "adviseContext cannot be null"); - Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot be null"); - Assert.noNullElements(adviseContext.values(), "adviseContext values cannot be null"); - } - - /** - * Create a new {@link Builder} instance. - * @return a new {@link Builder} instance - */ - public static Builder builder() { - return new Builder(); - } - - /** - * Create a new {@link Builder} instance from the provided {@link AdvisedResponse}. - * @param advisedResponse the advised response to copy - * @return a new {@link Builder} instance - */ - public static Builder from(AdvisedResponse advisedResponse) { - Assert.notNull(advisedResponse, "advisedResponse cannot be null"); - return new Builder().response(advisedResponse.response).adviseContext(advisedResponse.adviseContext); - } - - public static AdvisedResponse from(ChatClientResponse chatClientResponse) { - Assert.notNull(chatClientResponse, "chatClientResponse cannot be null"); - return new AdvisedResponse(chatClientResponse.chatResponse(), chatClientResponse.context()); - } - - public ChatClientResponse toChatClientResponse() { - return new ChatClientResponse(this.response, this.adviseContext); - } - - /** - * Update the context of the advised response. - * @param contextTransform the function to transform the context - * @return the updated advised response - */ - public AdvisedResponse updateContext(Function, Map> contextTransform) { - Assert.notNull(contextTransform, "contextTransform cannot be null"); - return new AdvisedResponse(this.response, - Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext)))); - } - - /** - * Builder for {@link AdvisedResponse}. - */ - public static final class Builder { - - @Nullable - private ChatResponse response; - - private Map adviseContext; - - private Builder() { - } - - /** - * Set the chat response. - * @param response the chat response - * @return the builder - */ - public Builder response(@Nullable ChatResponse response) { - this.response = response; - return this; - } - - /** - * Set the context to advise the response. - * @param adviseContext the context to advise the response - * @return the builder - */ - public Builder adviseContext(Map adviseContext) { - this.adviseContext = adviseContext; - return this; - } - - /** - * Build the {@link AdvisedResponse}. - * @return the advised response - */ - public AdvisedResponse build() { - return new AdvisedResponse(this.response, this.adviseContext); - } - - } - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java deleted file mode 100644 index 169c8c4e562..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtils.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.function.Predicate; - -import org.springframework.ai.chat.client.advisor.AdvisorUtils; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.util.StringUtils; - -/** - * A stream utility class to provide support methods handling {@link AdvisedResponse}. - * - * @deprecated in favour of {@link AdvisorUtils}. - */ -@Deprecated -public final class AdvisedResponseStreamUtils { - - private AdvisedResponseStreamUtils() { - // Avoids instantiation - } - - /** - * Returns a predicate that checks whether the provided {@link AdvisedResponse} - * contains a {@link ChatResponse} with at least one result having a non-empty finish - * reason in its metadata. - * @return a {@link Predicate} that evaluates whether the finish reason exists within - * the response metadata. - */ - public static Predicate onFinishReason() { - return advisedResponse -> { - ChatResponse chatResponse = advisedResponse.response(); - return chatResponse != null && chatResponse.getResults() != null - && chatResponse.getResults() - .stream() - .anyMatch(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())); - }; - } - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java index 5ce25bb4aac..007aac02bfc 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisor.java @@ -31,10 +31,10 @@ * {@link StreamAdvisor}, reducing the boilerplate code needed to implement an advisor. *

* It provides default implementations for the - * {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)} and - * {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} methods, delegating - * the actual logic to the {@link #before(ChatClientRequest, AdvisorChain advisorChain)} - * and {@link #after(ChatClientResponse, AdvisorChain advisorChain)} methods. + * {@link #adviseCall(ChatClientRequest, CallAdvisorChain)} and + * {@link #adviseStream(ChatClientRequest, StreamAdvisorChain)} methods, delegating the + * actual logic to the {@link #before(ChatClientRequest, AdvisorChain advisorChain)} and + * {@link #after(ChatClientResponse, AdvisorChain advisorChain)} methods. * * @author Thomas Vitale * @since 1.0.0 @@ -44,82 +44,35 @@ public interface BaseAdvisor extends CallAdvisor, StreamAdvisor { Scheduler DEFAULT_SCHEDULER = Schedulers.boundedElastic(); @Override - default ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) { + default ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); - Assert.notNull(chain, "chain cannot be null"); + Assert.notNull(callAdvisorChain, "callAdvisorChain cannot be null"); - ChatClientRequest processedChatClientRequest = before(chatClientRequest, chain); - ChatClientResponse chatClientResponse; - if (chain instanceof CallAdvisorChain callAdvisorChain) { - chatClientResponse = callAdvisorChain.nextCall(processedChatClientRequest); - } - else { - chatClientResponse = chain.nextAroundCall(AdvisedRequest.from(processedChatClientRequest)) - .toChatClientResponse(); - } - return after(chatClientResponse, chain); + ChatClientRequest processedChatClientRequest = before(chatClientRequest, callAdvisorChain); + ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(processedChatClientRequest); + return after(chatClientResponse, callAdvisorChain); } @Override - @Deprecated - default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - Assert.notNull(advisedRequest, "advisedRequest cannot be null"); - Assert.notNull(chain, "chain cannot be null"); - - AdvisedRequest processedAdvisedRequest = before(advisedRequest); - AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest); - return after(advisedResponse); - } - - @Override - default Flux adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) { + default Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); - Assert.notNull(chain, "chain cannot be null"); + Assert.notNull(streamAdvisorChain, "streamAdvisorChain cannot be null"); Assert.notNull(getScheduler(), "scheduler cannot be null"); - Flux chatClientResponseFlux; - if (chain instanceof StreamAdvisorChain streamAdvisorChain) { - chatClientResponseFlux = Mono.just(chatClientRequest) - .publishOn(getScheduler()) - .map(request -> this.before(request, streamAdvisorChain)) - .flatMapMany(streamAdvisorChain::nextStream); - } - else { - chatClientResponseFlux = Mono.just(AdvisedRequest.from(chatClientRequest)) - .publishOn(getScheduler()) - .map(this::before) - .flatMapMany(chain::nextAroundStream) - .map(AdvisedResponse::toChatClientResponse); - } + Flux chatClientResponseFlux = Mono.just(chatClientRequest) + .publishOn(getScheduler()) + .map(request -> this.before(request, streamAdvisorChain)) + .flatMapMany(streamAdvisorChain::nextStream); return chatClientResponseFlux.map(response -> { if (AdvisorUtils.onFinishReason().test(response)) { - response = after(response, chain); + response = after(response, streamAdvisorChain); } return response; }).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error))); } - @Override - @Deprecated - default Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - Assert.notNull(advisedRequest, "advisedRequest cannot be null"); - Assert.notNull(chain, "chain cannot be null"); - Assert.notNull(getScheduler(), "scheduler cannot be null"); - - Flux advisedResponses = Mono.just(advisedRequest) - .publishOn(getScheduler()) - .map(this::before) - .flatMapMany(chain::nextAroundStream); - - return advisedResponses.map(ar -> { - if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) { - ar = after(ar); - } - return ar; - }).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error))); - } - @Override default String getName() { return this.getClass().getSimpleName(); @@ -128,32 +81,12 @@ default String getName() { /** * Logic to be executed before the rest of the advisor chain is called. */ - default ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) { - Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); - return before(AdvisedRequest.from(chatClientRequest)).toChatClientRequest(); - } - - /** - * Logic to be executed after the rest of the advisor chain is called. - */ - default ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) { - Assert.notNull(chatClientResponse, "chatClientResponse cannot be null"); - return after(AdvisedResponse.from(chatClientResponse)).toChatClientResponse(); - } - - /** - * Logic to be executed before the rest of the advisor chain is called. - * @deprecated in favor of {@link #before(ChatClientRequest,AdvisorChain)} - */ - @Deprecated - AdvisedRequest before(AdvisedRequest request); + ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain); /** * Logic to be executed after the rest of the advisor chain is called. - * @deprecated in favor of {@link #after(ChatClientResponse,AdvisorChain)} */ - @Deprecated - AdvisedResponse after(AdvisedResponse advisedResponse); + ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain); /** * Scheduler used for processing the advisor logic when streaming. diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java index 4ea605cb452..7957d48e6e2 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/BaseAdvisorChain.java @@ -16,9 +16,6 @@ package org.springframework.ai.chat.client.advisor.api; -import org.springframework.ai.template.TemplateRenderer; -import org.springframework.ai.template.st.StTemplateRenderer; - /** * A base interface for advisor chains that can be used to chain multiple advisors * together, both for call and stream advisors. @@ -28,8 +25,4 @@ */ public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain { - default TemplateRenderer getTemplateRenderer() { - return StTemplateRenderer.builder().build(); - } - } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java index 6478b92f51a..182dae19321 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisor.java @@ -22,20 +22,13 @@ /** * Advisor for execution flows ultimately resulting in a call to an AI model * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @author Thomas Vitale * @since 1.0.0 */ -public interface CallAdvisor extends CallAroundAdvisor { +public interface CallAdvisor extends Advisor { - /** - * @deprecated use {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)} - */ - @Deprecated - default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - ChatClientResponse chatClientResponse = adviseCall(advisedRequest.toChatClientRequest(), chain); - return AdvisedResponse.from(chatClientResponse); - } - - ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain); + ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java index e6d17d2fd41..13624abce93 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAdvisorChain.java @@ -25,22 +25,23 @@ * A chain of {@link CallAdvisor} instances orchestrating the execution of a * {@link ChatClientRequest} on the next {@link CallAdvisor} in the chain. * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @author Thomas Vitale * @since 1.0.0 */ -public interface CallAdvisorChain extends CallAroundAdvisorChain { +public interface CallAdvisorChain extends AdvisorChain { /** - * @deprecated use {@link #nextCall(ChatClientRequest)} + * Invokes the next {@link CallAdvisor} in the {@link CallAdvisorChain} with the given + * request. */ - @Deprecated - default AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) { - ChatClientResponse chatClientResponse = nextCall(advisedRequest.toChatClientRequest()); - return AdvisedResponse.from(chatClientResponse); - } - ChatClientResponse nextCall(ChatClientRequest chatClientRequest); - List getCallAdvisors(); + /** + * Returns the list of all the {@link CallAdvisor} instances included in this chain at + * the time of its creation. + */ + List getCallAdvisors(); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java deleted file mode 100644 index 3faaf36a599..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisor.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import org.springframework.ai.chat.client.ChatClientRequest; - -/** - * Around advisor that wraps the ChatModel#call(Prompt) method. - * - * @author Christian Tzolov - * @author Dariusz Jedrzejczyk - * @since 1.0.0 - * @deprecated in favor of {@link CallAdvisor} - */ -@Deprecated -public interface CallAroundAdvisor extends Advisor { - - /** - * Around advice that wraps the ChatModel#call(Prompt) method. - * @param advisedRequest the advised request - * @param chain the advisor chain - * @return the response - * @deprecated in favor of - * {@link CallAdvisor#adviseCall(ChatClientRequest, CallAroundAdvisorChain)} - */ - @Deprecated - AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain); - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java deleted file mode 100644 index e768f45e1cc..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import org.springframework.ai.chat.client.ChatClientRequest; - -/** - * The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain. - * Used for non-streaming responses. - * - * @author Christian Tzolov - * @author Dariusz Jedrzejczyk - * @since 1.0.0 - * @deprecated in favor of {@link CallAdvisorChain} - */ -@Deprecated -public interface CallAroundAdvisorChain extends AdvisorChain { - - /** - * Invokes the next Around Advisor in the CallAroundAdvisorChain with the given - * request. - * @param advisedRequest the request containing the data to be processed by the next - * advisor in the chain. - * @return the response generated by the next advisor in the chain. - * @deprecated in favor of {@link CallAdvisorChain#nextCall(ChatClientRequest)} - */ - @Deprecated - AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest); - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java index 9ea441d4486..a5e3becfff9 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisor.java @@ -23,20 +23,13 @@ /** * Advisor for execution flows ultimately resulting in a streaming call to an AI model. * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @author Thomas Vitale * @since 1.0.0 */ -public interface StreamAdvisor extends StreamAroundAdvisor { +public interface StreamAdvisor extends Advisor { - /** - * @deprecated use {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} - */ - @Deprecated - default Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - Flux chatClientResponse = adviseStream(advisedRequest.toChatClientRequest(), chain); - return chatClientResponse.map(AdvisedResponse::from); - } - - Flux adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain); + Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java index 66e60f4c769..2b99dc47a24 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAdvisorChain.java @@ -26,22 +26,23 @@ * A chain of {@link StreamAdvisor} instances orchestrating the execution of a * {@link ChatClientRequest} on the next {@link StreamAdvisor} in the chain. * + * @author Christian Tzolov + * @author Dariusz Jedrzejczyk * @author Thomas Vitale * @since 1.0.0 */ -public interface StreamAdvisorChain extends StreamAroundAdvisorChain { +public interface StreamAdvisorChain extends AdvisorChain { /** - * @deprecated use {@link #nextStream(ChatClientRequest)} + * Invokes the next {@link StreamAdvisor} in the {@link StreamAdvisorChain} with the + * given request. */ - @Deprecated - default Flux nextAroundStream(AdvisedRequest advisedRequest) { - Flux chatClientResponse = nextStream(advisedRequest.toChatClientRequest()); - return chatClientResponse.map(AdvisedResponse::from); - } - Flux nextStream(ChatClientRequest chatClientRequest); - List getStreamAdvisors(); + /** + * Returns the list of all the {@link StreamAdvisor} instances included in this chain + * at the time of its creation. + */ + List getStreamAdvisors(); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java deleted file mode 100644 index d7145e14246..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisor.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import org.springframework.ai.chat.client.ChatClientRequest; -import reactor.core.publisher.Flux; - -/** - * Around advisor that runs around stream based requests. - * - * @author Christian Tzolov - * @author Dariusz Jedrzejczyk - * @since 1.0.0 - * @deprecated in favor of {@link StreamAdvisor} - */ -@Deprecated -public interface StreamAroundAdvisor extends Advisor { - - /** - * Around advice that wraps the invocation of the advised request. - * @param advisedRequest the advised request - * @param chain the chain of advisors to execute - * @return the result of the advised request - * @deprecated in favor of - * {@link StreamAdvisor#adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} - */ - @Deprecated - Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain); - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java deleted file mode 100644 index abb4a62e60f..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import org.springframework.ai.chat.client.ChatClientRequest; -import reactor.core.publisher.Flux; - -/** - * The StreamAroundAdvisorChain is used to delegate the call to the next - * StreamAroundAdvisor in the chain. Used for streaming responses. - * - * @author Christian Tzolov - * @author Dariusz Jedrzejczyk - * @since 1.0.0 - * @deprecated in favor of {@link StreamAdvisorChain} - */ -@Deprecated -public interface StreamAroundAdvisorChain extends AdvisorChain { - - /** - * This method delegates the call to the next StreamAroundAdvisor in the chain and is - * used for streaming responses. - * @param advisedRequest the request containing data of the chat client that can be - * modified before execution - * @return a Flux stream of AdvisedResponse objects - * @deprecated in favor of {@link StreamAdvisorChain#nextStream(ChatClientRequest)} - */ - @Deprecated - Flux nextAroundStream(AdvisedRequest advisedRequest); - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java index c5005c92471..dc9a3735e19 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContext.java @@ -16,18 +16,12 @@ package org.springframework.ai.chat.client.advisor.observation; -import java.util.Map; - import io.micrometer.observation.Observation; -import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.lang.Nullable; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; /** * Context used to store metadata for chat client advisors. @@ -47,41 +41,6 @@ public class AdvisorObservationContext extends Observation.Context { @Nullable private ChatClientResponse chatClientResponse; - /** - * the shared data between the advisors in the chain. It is shared between all request - * and response advising points of all advisors in the chain. - */ - @Nullable - private Map advisorResponseContext; - - /** - * Create a new {@link AdvisorObservationContext}. - * @param advisorName the advisor name - * @param advisorType the advisor type - * @param advisorRequest the advised request - * @param advisorRequestContext the shared data between the advisors in the chain - * @param advisorResponseContext the shared data between the advisors in the chain - * @param order the order of the advisor in the advisor chain - * @deprecated use the builder instead - */ - @Deprecated - public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest, - @Nullable Map advisorRequestContext, @Nullable Map advisorResponseContext, - int order) { - Assert.hasText(advisorName, "advisorName cannot be null or empty"); - - this.advisorName = advisorName; - this.chatClientRequest = advisorRequest != null ? advisorRequest.toChatClientRequest() - : ChatClientRequest.builder().prompt(new Prompt()).build(); - if (!CollectionUtils.isEmpty(advisorRequestContext)) { - this.chatClientRequest.context().putAll(advisorRequestContext); - } - if (!CollectionUtils.isEmpty(advisorResponseContext)) { - this.chatClientResponse = ChatClientResponse.builder().context(advisorResponseContext).build(); - } - this.order = order; - } - AdvisorObservationContext(String advisorName, ChatClientRequest chatClientRequest, int order) { Assert.hasText(advisorName, "advisorName cannot be null or empty"); Assert.notNull(chatClientRequest, "chatClientRequest cannot be null"); @@ -120,111 +79,6 @@ public void setChatClientResponse(@Nullable ChatClientResponse chatClientRespons this.chatClientResponse = chatClientResponse; } - /** - * The type of the advisor. - * @return the type of the advisor - * @deprecated advisors don't have types anymore, they're all "around" - */ - @Deprecated - public Type getAdvisorType() { - return Type.AROUND; - } - - /** - * The order of the advisor in the advisor chain. - * @return the order of the advisor in the advisor chain - * @deprecated not used anymore - */ - @Deprecated - public AdvisedRequest getAdvisedRequest() { - return AdvisedRequest.from(this.chatClientRequest); - } - - /** - * Set the {@link AdvisedRequest} data to be advised. Represents the row - * {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}. - * @param advisedRequest the advised request - * @deprecated immutable object, use the builder instead to create a new instance - */ - @Deprecated - public void setAdvisedRequest(@Nullable AdvisedRequest advisedRequest) { - throw new IllegalStateException( - "The AdvisedRequest is immutable. Build a new AdvisorObservationContext instead."); - } - - /** - * Get the shared data between the advisors in the chain. It is shared between all - * request and response advising points of all advisors in the chain. - * @return the shared data between the advisors in the chain - * @deprecated use {@link #getChatClientRequest()} instead - */ - @Deprecated - public Map getAdvisorRequestContext() { - return this.chatClientRequest.context(); - } - - /** - * Set the shared data between the advisors in the chain. It is shared between all - * request and response advising points of all advisors in the chain. - * @param advisorRequestContext the shared data between the advisors in the chain - * @deprecated not supported anymore, use {@link #getChatClientRequest()} instead - */ - @Deprecated - public void setAdvisorRequestContext(@Nullable Map advisorRequestContext) { - if (!CollectionUtils.isEmpty(advisorRequestContext)) { - this.chatClientRequest.context().putAll(advisorRequestContext); - } - } - - /** - * Get the shared data between the advisors in the chain. It is shared between all - * request and response advising points of all advisors in the chain. - * @return the shared data between the advisors in the chain - * @deprecated use {@link #getChatClientResponse()} instead - */ - @Nullable - @Deprecated - public Map getAdvisorResponseContext() { - if (this.chatClientResponse != null) { - return this.chatClientResponse.context(); - } - return null; - } - - /** - * Set the shared data between the advisors in the chain. It is shared between all - * request and response advising points of all advisors in the chain. - * @param advisorResponseContext the shared data between the advisors in the chain - * @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead - */ - @Deprecated - public void setAdvisorResponseContext(@Nullable Map advisorResponseContext) { - this.advisorResponseContext = advisorResponseContext; - } - - /** - * The type of the advisor. - * - * @deprecated advisors don't have types anymore, they're all "around" - */ - @Deprecated - public enum Type { - - /** - * The advisor is called before the advised request is executed. - */ - BEFORE, - /** - * The advisor is called after the advised request is executed. - */ - AFTER, - /** - * The advisor is called around the advised request. - */ - AROUND - - } - /** * Builder for {@link AdvisorObservationContext}. */ @@ -236,12 +90,6 @@ public static final class Builder { private int order = 0; - private AdvisedRequest advisorRequest; - - private Map advisorRequestContext; - - private Map advisorResponseContext; - private Builder() { } @@ -260,65 +108,8 @@ public Builder order(int order) { return this; } - /** - * Set the advisor type. - * @param advisorType the advisor type - * @return the builder - * @deprecated advisors don't have types anymore, they're all "around" - */ - @Deprecated - public Builder advisorType(Type advisorType) { - return this; - } - - /** - * Set the advised request. - * @param advisedRequest the advised request - * @return the builder - * @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead - */ - @Deprecated - public Builder advisedRequest(AdvisedRequest advisedRequest) { - this.advisorRequest = advisedRequest; - return this; - } - - /** - * Set the advisor request context. - * @param advisorRequestContext the advisor request context - * @return the builder - * @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead - */ - @Deprecated - public Builder advisorRequestContext(Map advisorRequestContext) { - this.advisorRequestContext = advisorRequestContext; - return this; - } - - /** - * Set the advisor response context. - * @param advisorResponseContext the advisor response context - * @return the builder - * @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead - */ - @Deprecated - public Builder advisorResponseContext(Map advisorResponseContext) { - this.advisorResponseContext = advisorResponseContext; - return this; - } - public AdvisorObservationContext build() { - if (chatClientRequest != null && advisorRequest != null) { - throw new IllegalArgumentException( - "ChatClientRequest and AdvisedRequest cannot be set at the same time"); - } - else if (chatClientRequest != null) { - return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order); - } - else { - return new AdvisorObservationContext(this.advisorName, Type.AROUND, this.advisorRequest, - this.advisorRequestContext, this.advisorResponseContext, this.order); - } + return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java index f332db04cca..4d80bc7b736 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationDocumentation.java @@ -87,18 +87,6 @@ public String asString() { } }, - /** - * Advisor type: Before, After or Around. - * @deprecated advisors don't have types anymore, they're all "around" - */ - @Deprecated - ADVISOR_TYPE { - @Override - public String asString() { - return "spring.ai.advisor.type"; - } - }, - /** * Advisor name. */ diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java index 4f1e513628a..ddfb110e4a1 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConvention.java @@ -69,8 +69,7 @@ public String getContextualName(AdvisorObservationContext context) { @Override public KeyValues getLowCardinalityKeyValues(AdvisorObservationContext context) { Assert.notNull(context, "context cannot be null"); - return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorType(context), - advisorName(context)); + return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorName(context)); } protected KeyValue aiOperationType(AdvisorObservationContext context) { @@ -81,14 +80,6 @@ protected KeyValue aiProvider(AdvisorObservationContext context) { return KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER, AiProvider.SPRING_AI.value()); } - /** - * @deprecated advisors don't have types anymore, they're all "around" - */ - @Deprecated - protected KeyValue advisorType(AdvisorObservationContext context) { - return KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE, context.getAdvisorType().name()); - } - protected KeyValue springAiKind() { return KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND, SpringAiKind.ADVISOR.value()); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java deleted file mode 100644 index fdd5042602d..00000000000 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilter.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.observation; - -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationFilter; - -import org.springframework.ai.chat.client.ChatClientAttributes; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.observation.tracing.TracingHelper; -import org.springframework.util.CollectionUtils; - -import java.util.List; -import java.util.Map; - -/** - * An {@link ObservationFilter} to include the chat prompt content in the observation. - * - * @author Christian Tzolov - * @since 1.0.0 - * @deprecated in favor of {@link ChatClientPromptContentObservationFilter}. - */ -@Deprecated -public class ChatClientInputContentObservationFilter implements ObservationFilter { - - @Override - public Observation.Context map(Observation.Context context) { - if (!(context instanceof ChatClientObservationContext chatClientObservationContext)) { - return context; - } - chatClientSystemText(chatClientObservationContext); - chatClientSystemParams(chatClientObservationContext); - chatClientUserText(chatClientObservationContext); - chatClientUserParams(chatClientObservationContext); - - return chatClientObservationContext; - } - - protected void chatClientSystemText(ChatClientObservationContext context) { - List messages = context.getRequest().prompt().getInstructions(); - if (CollectionUtils.isEmpty(messages)) { - return; - } - - var systemMessage = messages.stream() - .filter(message -> message instanceof SystemMessage) - .reduce((first, second) -> second); - if (systemMessage.isEmpty()) { - return; - } - context.addHighCardinalityKeyValue( - KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT, - systemMessage.get().getText())); - } - - @SuppressWarnings("unchecked") - protected void chatClientSystemParams(ChatClientObservationContext context) { - if (!(context.getRequest() - .context() - .get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map systemParams)) { - return; - } - if (CollectionUtils.isEmpty(systemParams)) { - return; - } - - context.addHighCardinalityKeyValue( - KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM, - TracingHelper.concatenateMaps((Map) systemParams))); - } - - protected void chatClientUserText(ChatClientObservationContext context) { - List messages = context.getRequest().prompt().getInstructions(); - if (CollectionUtils.isEmpty(messages)) { - return; - } - - if (!(messages.get(messages.size() - 1) instanceof UserMessage userMessage)) { - return; - } - context.addHighCardinalityKeyValue( - KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT, - userMessage.getText())); - } - - @SuppressWarnings("unchecked") - protected void chatClientUserParams(ChatClientObservationContext context) { - if (!(context.getRequest() - .context() - .get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map userParams)) { - return; - } - if (CollectionUtils.isEmpty(userParams)) { - return; - } - context.addHighCardinalityKeyValue( - KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS, - TracingHelper.concatenateMaps((Map) userParams))); - } - -} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java index 41d8afa71a2..f12de8feac0 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationContext.java @@ -78,12 +78,7 @@ public boolean isStream() { return this.stream; } - /** - * @deprecated not used anymore. The format instructions are already included in the - * ChatModelObservationContext. - */ @Nullable - @Deprecated public String getFormat() { if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) { return format; @@ -91,21 +86,13 @@ public String getFormat() { return null; } - /** - * @deprecated not used anymore. The format instructions are already included in the - * ChatModelObservationContext. - */ - @Deprecated - public void setFormat(@Nullable String format) { - this.request.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format); - } - public static final class Builder { private ChatClientRequest chatClientRequest; private List advisors = List.of(); + @Nullable private String format; private boolean isStream = false; @@ -118,17 +105,7 @@ public Builder request(ChatClientRequest chatClientRequest) { return this; } - @Deprecated // use request(ChatClientRequest chatClientRequest) - public Builder withRequest(ChatClientRequest chatClientRequest) { - return request(chatClientRequest); - } - - /** - * @deprecated not used anymore. The format instructions are already included in - * the ChatModelObservationContext. - */ - @Deprecated - public Builder withFormat(String format) { + public Builder format(@Nullable String format) { this.format = format; return this; } @@ -143,11 +120,6 @@ public Builder stream(boolean isStream) { return this; } - @Deprecated // use stream(boolean isStream) - public Builder withStream(boolean isStream) { - return stream(isStream); - } - public ChatClientObservationContext build() { if (StringUtils.hasText(format)) { this.chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java index 7eb536edd83..b1c2ee2e543 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/ChatClientObservationDocumentation.java @@ -109,94 +109,6 @@ public String asString() { } }, - /** - * Enabled tool function names. - * @deprecated replaced by {@link #CHAT_CLIENT_TOOL_NAMES} - */ - @Deprecated - CHAT_CLIENT_TOOL_FUNCTION_NAMES { - @Override - public String asString() { - return "spring.ai.chat.client.tool.function.names"; - } - }, - - /** - * List of configured chat client function callbacks. - * @deprecated replaced by {@link #CHAT_CLIENT_TOOL_NAMES} - */ - @Deprecated - CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS { - @Override - public String asString() { - return "spring.ai.chat.client.tool.function.callbacks"; - } - }, - - /** - * Map of advisor parameters. - * @deprecated risk to expose sensitive information or break the instrumentation - * since the advisor context map is used to pass arbitrary Java objects between - * advisors and not necessarily serializable. The conversation ID, previously part - * of this, is already included in the {@link #CHAT_CLIENT_CONVERSATION_ID} - * method. - */ - @Deprecated - CHAT_CLIENT_ADVISOR_PARAMS { - @Override - public String asString() { - return "spring.ai.chat.client.advisor.params"; - } - }, - - /** - * Chat client user text. - * @deprecated replaced by {@link #PROMPT} - */ - @Deprecated - CHAT_CLIENT_USER_TEXT { - @Override - public String asString() { - return "spring.ai.chat.client.user.text"; - } - }, - - /** - * Chat client user parameters. - * @deprecated replaced by {@link #PROMPT} - */ - @Deprecated - CHAT_CLIENT_USER_PARAMS { - @Override - public String asString() { - return "spring.ai.chat.client.user.params"; - } - }, - - /** - * Chat client system text. - * @deprecated replaced by {@link #PROMPT} - */ - @Deprecated - CHAT_CLIENT_SYSTEM_TEXT { - @Override - public String asString() { - return "spring.ai.chat.client.system.text"; - } - }, - - /** - * Chat client system parameters. - * @deprecated replaced by {@link #PROMPT} - */ - @Deprecated - CHAT_CLIENT_SYSTEM_PARAM { - @Override - public String asString() { - return "spring.ai.chat.client.system.params"; - } - }, - // Content /** diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java index 0bfdfc5ae8f..431ce9c843d 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConvention.java @@ -16,14 +16,11 @@ package org.springframework.ai.chat.client.observation; -import java.util.Arrays; import java.util.ArrayList; -import java.util.HashMap; import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; -import org.springframework.ai.chat.client.ChatClientAttributes; import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; @@ -97,12 +94,6 @@ public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext contex keyValues = advisors(keyValues, context); keyValues = conversationId(keyValues, context); keyValues = tools(keyValues, context); - // @deprecated remove before 1.0.0-RC1. - keyValues = chatClientAdvisorParams(keyValues, context); - // @deprecated remove before 1.0.0-RC1. - keyValues = toolNames(keyValues, context); - // @deprecated remove before 1.0.0-RC1. - keyValues = toolCallbacks(keyValues, context); return keyValues; } @@ -123,7 +114,8 @@ protected KeyValues conversationId(KeyValues keyValues, ChatClientObservationCon var conversationIdValue = context.getRequest() .context() .get(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY); - if (!(conversationIdValue instanceof String conversationId) || StringUtils.isEmpty(conversationId)) { + + if (!(conversationIdValue instanceof String conversationId) || !StringUtils.hasText(conversationId)) { return keyValues; } @@ -154,71 +146,4 @@ protected KeyValues tools(KeyValues keyValues, ChatClientObservationContext cont TracingHelper.concatenateStrings(toolNames.stream().sorted().toList())); } - /** - * @deprecated risk to expose sensitive information or break the instrumentation since - * the advisor context map is used to pass arbitrary Java objects between advisors and - * not necessarily serializable. The conversation ID, previously part of this, is - * already included in the - * {@link #conversationId(KeyValues, ChatClientObservationContext)} method. - */ - @Deprecated - protected KeyValues chatClientAdvisorParams(KeyValues keyValues, ChatClientObservationContext context) { - if (CollectionUtils.isEmpty(context.getRequest().context())) { - return keyValues; - } - var chatClientContext = new HashMap<>(context.getRequest().context()); - Arrays.stream(ChatClientAttributes.values()).forEach(attribute -> chatClientContext.remove(attribute.getKey())); - return keyValues.and( - ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), - TracingHelper.concatenateMaps(chatClientContext)); - } - - /** - * @deprecated in favor of {@link #tools(KeyValues, ChatClientObservationContext)} - */ - @Deprecated - protected KeyValues toolNames(KeyValues keyValues, ChatClientObservationContext context) { - if (context.getRequest().prompt().getOptions() == null) { - return keyValues; - } - if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) { - return keyValues; - } - - var toolNames = options.getToolNames(); - if (CollectionUtils.isEmpty(toolNames)) { - return keyValues; - } - - return keyValues.and( - ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), - TracingHelper.concatenateStrings(toolNames.stream().sorted().toList())); - } - - /** - * @deprecated in favor of {@link #tools(KeyValues, ChatClientObservationContext)} - */ - @Deprecated - protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationContext context) { - if (context.getRequest().prompt().getOptions() == null) { - return keyValues; - } - if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) { - return keyValues; - } - - var toolCallbacks = options.getToolCallbacks(); - if (CollectionUtils.isEmpty(toolCallbacks)) { - return keyValues; - } - - var toolCallbackNames = toolCallbacks.stream() - .map(toolCallback -> toolCallback.getToolDefinition().name()) - .sorted() - .toList(); - return keyValues - .and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS - .asString(), TracingHelper.concatenateStrings(toolCallbackNames)); - } - } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java index b94a91f6dbd..433050cb38a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,9 +24,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClientResponse; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -42,31 +42,28 @@ * * @author Christian Tzolov * @author Alexandros Pappas + * @author Thomas Vitale * @since 1.0.0 */ public class MessageAggregator { private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class); - public Flux aggregateAdvisedResponse(Flux advisedResponses, - Consumer aggregationHandler) { + public Flux aggregateChatClientResponse(Flux chatClientResponses, + Consumer aggregationHandler) { - AtomicReference> adviseContext = new AtomicReference<>(new HashMap<>()); - - return new MessageAggregator().aggregate(advisedResponses.map(ar -> { - adviseContext.get().putAll(ar.adviseContext()); - return ar.response(); + AtomicReference> context = new AtomicReference<>(new HashMap<>()); + return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> { + context.get().putAll(chatClientResponse.context()); + return chatClientResponse.chatResponse(); }), aggregatedChatResponse -> { - - AdvisedResponse aggregatedAdvisedResponse = AdvisedResponse.builder() - .response(aggregatedChatResponse) - .adviseContext(adviseContext.get()) + ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder() + .chatResponse(aggregatedChatResponse) + .context(context.get()) .build(); - - aggregationHandler.accept(aggregatedAdvisedResponse); - - }).map(cr -> new AdvisedResponse(cr, adviseContext.get())); + aggregationHandler.accept(aggregatedChatClientResponse); + }).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build()); } public Flux aggregate(Flux fluxChatResponse, diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java index 87285182f74..783a7356c0a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/ChatClientTest.java @@ -713,7 +713,7 @@ void whenPromptWithMessagesAndSystemText() { assertThat(content).isEqualTo("response"); assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); - var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @@ -747,7 +747,7 @@ void whenPromptWithSystemMessageAndSystemText() { assertThat(content).isEqualTo("response"); assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); - var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("other instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @@ -769,7 +769,7 @@ void whenMessagesAndSystemText() { assertThat(content).isEqualTo("response"); assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); - var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } @@ -808,7 +808,7 @@ void whenMessagesWithSystemMessageAndSystemText() { assertThat(content).isEqualTo("response"); assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4); - var systemMessage = this.promptCaptor.getValue().getInstructions().get(2); + var systemMessage = this.promptCaptor.getValue().getInstructions().get(0); assertThat(systemMessage.getText()).isEqualTo("other instructions"); assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java new file mode 100644 index 00000000000..3f0ed5148e5 --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -0,0 +1,448 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.content.Media; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.template.TemplateRenderer; +import org.springframework.ai.template.st.StTemplateRenderer; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.metadata.ToolMetadata; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link DefaultChatClientUtils}. + * + * @author Thomas Vitale + */ +class DefaultChatClientUtilsTests { + + @Test + void whenInputRequestIsNullThenThrows() { + assertThatThrownBy(() -> DefaultChatClientUtils.toChatClientRequest(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputRequest cannot be null"); + } + + @Test + void whenSystemTextIsProvidedThenSystemMessageIsAddedToPrompt() { + String systemText = "System instructions"; + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .system(systemText); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(systemText); + } + + @Test + void whenSystemTextWithParamsIsProvidedThenSystemMessageIsRenderedAndAddedToPrompt() { + String systemText = "System instructions for {name}"; + Map systemParams = Map.of("name", "Spring AI"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .system(s -> s.text(systemText).params(systemParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System instructions for Spring AI"); + } + + @Test + void whenMessagesAreProvidedThenTheyAreAddedToPrompt() { + List messages = List.of(new SystemMessage("System message"), new UserMessage("User message")); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .messages(messages); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).hasSize(2); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System message"); + assertThat(result.prompt().getInstructions().get(1).getText()).isEqualTo("User message"); + } + + @Test + void whenUserTextIsProvidedThenUserMessageIsAddedToPrompt() { + String userText = "User question"; + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .user(userText); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(userText); + } + + @Test + void whenUserTextWithParamsIsProvidedThenUserMessageIsRenderedAndAddedToPrompt() { + String userText = "Question about {topic}"; + Map userParams = Map.of("topic", "Spring AI"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .user(s -> s.text(userText).params(userParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("Question about Spring AI"); + } + + @Test + void whenUserTextWithMediaIsProvidedThenUserMessageWithMediaIsAddedToPrompt() { + String userText = "What's in this image?"; + Media media = mock(Media.class); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .user(s -> s.text(userText).media(media)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class); + UserMessage userMessage = (UserMessage) result.prompt().getInstructions().get(0); + assertThat(userMessage.getText()).isEqualTo(userText); + assertThat(userMessage.getMedia()).contains(media); + } + + @Test + void whenSystemTextAndSystemMessageAreProvidedThenSystemTextIsFirst() { + String systemText = "System instructions"; + List messages = List.of(new SystemMessage("System message")); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .system(systemText) + .messages(messages); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).hasSize(2); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(systemText); + } + + @Test + void whenUserTextAndUserMessageAreProvidedThenUserTextIsLast() { + String userText = "User question"; + List messages = List.of(new UserMessage("User message")); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .user(userText) + .messages(messages); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).hasSize(2); + assertThat(result.prompt().getInstructions()).last().isInstanceOf(UserMessage.class); + assertThat(result.prompt().getInstructions()).last().extracting(Message::getText).isEqualTo(userText); + } + + @Test + void whenToolCallingChatOptionsIsProvidedThenToolNamesAreSet() { + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + List toolNames = List.of("tool1", "tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolNames(toolNames.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + } + + @Test + void whenToolCallingChatOptionsIsProvidedThenToolCallbacksAreSet() { + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + ToolCallback toolCallback = new TestToolCallback("tool1"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolCallbacks(toolCallback); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + } + + @Test + void whenToolCallingChatOptionsIsProvidedThenToolContextIsSet() { + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + Map toolContext = Map.of("key", "value"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolContext(toolContext); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + } + + @Test + void whenToolNamesAndChatOptionsAreProvidedThenTheToolNamesOverride() { + Set toolNames1 = Set.of("toolA", "toolB"); + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolNames(toolNames1).build(); + List toolNames2 = List.of("tool1", "tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolNames(toolNames2.toArray(new String[0])); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames2); + } + + @Test + void whenToolCallbacksAndChatOptionsAreProvidedThenTheToolCallbacksOverride() { + ToolCallback toolCallback1 = new TestToolCallback("tool1"); + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallback1).build(); + ToolCallback toolCallback2 = new TestToolCallback("tool2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolCallbacks(toolCallback2); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolCallbacks()).containsExactlyInAnyOrder(toolCallback2); + } + + @Test + void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() { + Map toolContext1 = Map.of("key1", "value1"); + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolContext(toolContext1).build(); + Map toolContext2 = Map.of("key2", "value2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .options(chatOptions) + .toolContext(toolContext2); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext1) + .containsAllEntriesOf(toolContext2); + } + + @Test + void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() { + Map advisorParams = Map.of("key1", "value1", "key2", "value2"); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .advisors(a -> a.params(advisorParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.context()).containsAllEntriesOf(advisorParams); + } + + @Test + void whenCustomTemplateRendererIsProvidedThenItIsUsedForRendering() { + String systemText = "Instructions "; + Map systemParams = Map.of("name", "Spring AI"); + TemplateRenderer customRenderer = StTemplateRenderer.builder() + .startDelimiterToken('<') + .endDelimiterToken('>') + .build(); + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .system(s -> s.text(systemText).params(systemParams)) + .templateRenderer(customRenderer); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + assertThat(result.prompt().getInstructions()).isNotEmpty(); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("Instructions Spring AI"); + } + + @Test + void whenAllComponentsAreProvidedThenCompleteRequestIsCreated() { + String systemText = "System instructions for {name}"; + Map systemParams = Map.of("name", "Spring AI"); + + String userText = "Question about {topic}"; + Map userParams = Map.of("topic", "Spring AI"); + Media media = mock(Media.class); + + List messages = List.of(new UserMessage("Intermediate message")); + + ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build(); + List toolNames = List.of("tool1", "tool2"); + ToolCallback toolCallback = new TestToolCallback("tool3"); + Map toolContext = Map.of("toolKey", "toolValue"); + + Map advisorParams = Map.of("advisorKey", "advisorValue"); + + ChatModel chatModel = mock(ChatModel.class); + DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient + .create(chatModel) + .prompt() + .system(s -> s.text(systemText).params(systemParams)) + .user(u -> u.text(userText).params(userParams).media(media)) + .messages(messages) + .toolNames(toolNames.toArray(new String[0])) + .toolCallbacks(toolCallback) + .toolContext(toolContext) + .options(chatOptions) + .advisors(a -> a.params(advisorParams)); + + ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest); + + assertThat(result).isNotNull(); + + assertThat(result.prompt().getInstructions()).hasSize(3); + assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System instructions for Spring AI"); + assertThat(result.prompt().getInstructions().get(1).getText()).isEqualTo("Intermediate message"); + assertThat(result.prompt().getInstructions().get(2)).isInstanceOf(UserMessage.class); + assertThat(result.prompt().getInstructions().get(2).getText()).isEqualTo("Question about Spring AI"); + UserMessage userMessage = (UserMessage) result.prompt().getInstructions().get(2); + assertThat(userMessage.getMedia()).contains(media); + + assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class); + ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions(); + assertThat(resultOptions).isNotNull(); + assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames); + assertThat(resultOptions.getToolCallbacks()).contains(toolCallback); + assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext); + + assertThat(result.context()).containsAllEntriesOf(advisorParams); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + private final ToolMetadata toolMetadata; + + TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + this.toolMetadata = ToolMetadata.builder().build(); + } + + TestToolCallback(String name, boolean returnDirect) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return this.toolDefinition; + } + + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + +} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorUtilsTests.java similarity index 65% rename from spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java rename to spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorUtilsTests.java index 0d96cb89ccb..a9878c1b1d0 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseStreamUtilsTest.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorUtilsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,83 +14,84 @@ * limitations under the License. */ -package org.springframework.ai.chat.client.advisor.api; - -import java.util.List; +package org.springframework.ai.chat.client.advisor; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; - +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import java.util.List; + import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; /** - * Unit tests for {@link AdvisedResponseStreamUtils}. + * Unit tests for {@link AdvisorUtils}. * * @author ghdcksgml1 + * @author Thomas Vitale */ -class AdvisedResponseStreamUtilsTest { +class AdvisorUtilsTests { @Nested class OnFinishReason { @Test void whenChatResponseIsNullThenReturnFalse() { - AdvisedResponse response = mock(AdvisedResponse.class); - given(response.response()).willReturn(null); + ChatClientResponse chatClientResponse = mock(ChatClientResponse.class); + given(chatClientResponse.chatResponse()).willReturn(null); - boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse); assertFalse(result); } @Test void whenChatResponseResultsIsNullThenReturnFalse() { - AdvisedResponse response = mock(AdvisedResponse.class); + ChatClientResponse chatClientResponse = mock(ChatClientResponse.class); ChatResponse chatResponse = mock(ChatResponse.class); given(chatResponse.getResults()).willReturn(null); - given(response.response()).willReturn(chatResponse); + given(chatClientResponse.chatResponse()).willReturn(chatResponse); - boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse); assertFalse(result); } @Test void whenChatIsRunningThenReturnFalse() { - AdvisedResponse response = mock(AdvisedResponse.class); + ChatClientResponse chatClientResponse = mock(ChatClientResponse.class); ChatResponse chatResponse = mock(ChatResponse.class); Generation generation = new Generation(new AssistantMessage("running.."), ChatGenerationMetadata.NULL); given(chatResponse.getResults()).willReturn(List.of(generation)); - given(response.response()).willReturn(chatResponse); + given(chatClientResponse.chatResponse()).willReturn(chatResponse); - boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse); assertFalse(result); } @Test void whenChatIsStopThenReturnTrue() { - AdvisedResponse response = mock(AdvisedResponse.class); + ChatClientResponse chatClientResponse = mock(ChatClientResponse.class); ChatResponse chatResponse = mock(ChatResponse.class); Generation generation = new Generation(new AssistantMessage("finish."), ChatGenerationMetadata.builder().finishReason("STOP").build()); given(chatResponse.getResults()).willReturn(List.of(generation)); - given(response.response()).willReturn(chatResponse); + given(chatClientResponse.chatResponse()).willReturn(chatResponse); - boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response); + boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse); assertTrue(result); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java index fd92eb230cd..48b3da6873e 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/AdvisorsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,12 +30,12 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; +import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; +import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -84,8 +84,8 @@ public void callAdvisorsContextPropagation() { assertThat(content).isEqualTo("Hello John"); // AROUND - assertThat(mockAroundAdvisor1.advisedResponse.response()).isNotNull(); - assertThat(mockAroundAdvisor1.advisedResponse.adviseContext()).containsEntry("key1", "value1") + assertThat(mockAroundAdvisor1.chatClientResponse.chatResponse()).isNotNull(); + assertThat(mockAroundAdvisor1.chatClientResponse.context()).containsEntry("key1", "value1") .containsEntry("key2", "value2") .containsEntry("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1") .containsEntry("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1") @@ -126,10 +126,10 @@ public void streamAdvisorsContextPropagation() { assertThat(content).isEqualTo("Hello John"); // AROUND - assertThat(mockAroundAdvisor1.aroundAdvisedResponses).isNotEmpty(); + assertThat(mockAroundAdvisor1.advisedChatClientResponses).isNotEmpty(); - mockAroundAdvisor1.aroundAdvisedResponses.stream() - .forEach(advisedResponse -> assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1") + mockAroundAdvisor1.advisedChatClientResponses.stream() + .forEach(chatClientResponse -> assertThat(chatClientResponse.context()).containsEntry("key1", "value1") .containsEntry("key2", "value2") .containsEntry("aroundStreamBeforeAdvisor1", "AROUND_STREAM_BEFORE Advisor1") .containsEntry("aroundStreamAfterAdvisor1", "AROUND_STREAM_AFTER Advisor1") @@ -142,17 +142,17 @@ public void streamAdvisorsContextPropagation() { verify(this.chatModel).stream(this.promptCaptor.capture()); } - public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + public class MockAroundAdvisor implements CallAdvisor, StreamAdvisor { private final String name; private final int order; - public AdvisedRequest advisedRequest; + public ChatClientRequest chatClientRequest; - public AdvisedResponse advisedResponse; + public ChatClientResponse chatClientResponse; - public List aroundAdvisedResponses = new ArrayList<>(); + public List advisedChatClientResponses = new ArrayList<>(); public MockAroundAdvisor(String name, int order) { this.name = name; @@ -170,45 +170,38 @@ public int getOrder() { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { + this.chatClientRequest = chatClientRequest.mutate() + .context(Map.of("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName(), "lastBefore", + getName())) + .build(); - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName()); - context.put("lastBefore", getName()); - return context; - }); + var chatClientResponse = callAdvisorChain.nextCall(this.chatClientRequest); - this.advisedResponse = chain.nextAroundCall(this.advisedRequest); - AdvisedResponse advisedResponse = this.advisedResponse; + this.chatClientResponse = chatClientResponse.mutate() + .context( + Map.of("aroundCallAfter" + getName(), "AROUND_CALL_AFTER " + getName(), "lastAfter", getName())) + .build(); - this.advisedResponse = advisedResponse.updateContext(context -> { - context.put("aroundCallAfter" + this.name, "AROUND_CALL_AFTER " + this.name); - context.put("lastAfter", this.name); - return context; - }); - - return this.advisedResponse; + return this.chatClientResponse; } @Override - public Flux aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) { - - this.advisedRequest = advisedRequest.updateContext(context -> { - context.put("aroundStreamBefore" + this.name, "AROUND_STREAM_BEFORE " + this.name); - context.put("lastBefore", this.name); - return context; - }); - - Flux advisedResponseStream = chain.nextAroundStream(this.advisedRequest); - - return advisedResponseStream.map(advisedResponse -> { - return advisedResponse.updateContext(context -> { - context.put("aroundStreamAfter" + this.name, "AROUND_STREAM_AFTER " + this.name); - context.put("lastAfter", this.name); - return context; - }); - }).doOnNext(ar -> this.aroundAdvisedResponses.add(ar)); - + public Flux adviseStream(ChatClientRequest chatClientRequest, + StreamAdvisorChain streamAdvisorChain) { + this.chatClientRequest = chatClientRequest.mutate() + .context(Map.of("aroundStreamBefore" + getName(), "AROUND_STREAM_BEFORE " + getName(), "lastBefore", + getName())) + .build(); + + Flux chatClientResponseFlux = streamAdvisorChain.nextStream(this.chatClientRequest); + + return chatClientResponseFlux + .map(chatClientResponse -> chatClientResponse.mutate() + .context(Map.of("aroundStreamAfter" + getName(), "AROUND_STREAM_AFTER " + getName(), "lastAfter", + getName())) + .build()) + .doOnNext(ar -> this.advisedChatClientResponses.add(ar)); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java index 3a62b7a1fdd..b9f78de6f48 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisorTests.java @@ -103,7 +103,7 @@ private void validate(String content, CapturedOutput output) { UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(0); assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("Please answer my question XYZ"); - assertThat(output.getOut()).contains("request: AdvisedRequest", "userText=Please answer my question XYZ"); + assertThat(output.getOut()).contains("request: ChatClientRequest", "Please answer my question XYZ"); assertThat(output.getOut()).contains("response:", "finishReason"); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java deleted file mode 100644 index 806498d6039..00000000000 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java +++ /dev/null @@ -1,279 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.List; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.chat.client.ChatClientAttributes; -import org.springframework.ai.chat.client.ChatClientRequest; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.ChatOptions; -import org.springframework.ai.content.Media; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.template.TemplateRenderer; -import org.springframework.ai.template.st.StTemplateRenderer; -import org.springframework.ai.tool.ToolCallback; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - -/** - * Unit tests for {@link AdvisedRequest}. - * - * @author Thomas Vitale - */ -class AdvisedRequestTests { - - @Test - void buildAdvisedRequest() { - AdvisedRequest request = new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()); - assertThat(request).isNotNull(); - } - - @Test - void whenChatModelIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(null, "user", null, null, List.of(), List.of(), List.of(), - List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("chatModel cannot be null"); - } - - @Test - void whenUserTextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "userText cannot be null or empty unless messages are provided and contain Tool Response message."); - } - - @Test - void whenUserTextIsEmptyThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "userText cannot be null or empty unless messages are provided and contain Tool Response message."); - } - - @Test - void whenMediaIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("media cannot be null"); - } - - @Test - void whenFunctionNamesIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), null, - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("toolNames cannot be null"); - } - - @Test - void whenToolCallbacksIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - null, List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("toolCallbacks cannot be null"); - } - - @Test - void whenMessagesIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), null, Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("messages cannot be null"); - } - - @Test - void whenUserParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), null, Map.of(), List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("userParams cannot be null"); - } - - @Test - void whenSystemParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), null, List.of(), Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("systemParams cannot be null"); - } - - @Test - void whenAdvisorsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), null, Map.of(), Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("advisors cannot be null"); - } - - @Test - void whenAdvisorParamsIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), null, Map.of(), Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("advisorParams cannot be null"); - } - - @Test - void whenAdviseContextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), null, Map.of())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("adviseContext cannot be null"); - } - - @Test - void whenToolContextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(), - List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("toolContext cannot be null"); - } - - @Test - void whenConvertToAndFromChatClientRequestWithDefaultTemplateRenderer() { - ChatModel chatModel = mock(ChatModel.class); - ChatOptions chatOptions = ToolCallingChatOptions.builder().build(); - List messages = List.of(mock(UserMessage.class)); - SystemMessage systemMessage = new SystemMessage("Instructions {key}"); - UserMessage userMessage = UserMessage.builder().text("Question {key}").media(mock(Media.class)).build(); - Map systemParams = Map.of("key", "value"); - Map userParams = Map.of("key", "value"); - List toolNames = List.of("tool1", "tool2"); - ToolCallback toolCallback = mock(ToolCallback.class); - Map toolContext = Map.of("key", "value"); - List advisors = List.of(mock(Advisor.class)); - Map advisorContext = Map.of("key", "value"); - - AdvisedRequest advisedRequest = AdvisedRequest.builder() - .chatModel(chatModel) - .chatOptions(chatOptions) - .messages(messages) - .systemText(systemMessage.getText()) - .systemParams(systemParams) - .userText(userMessage.getText()) - .userParams(userParams) - .media(userMessage.getMedia()) - .toolNames(toolNames) - .functionCallbacks(List.of(toolCallback)) - .toolContext(toolContext) - .advisors(advisors) - .adviseContext(advisorContext) - .build(); - - ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest(); - - assertThat(chatClientRequest.context().get(ChatClientAttributes.CHAT_MODEL.getKey())).isEqualTo(chatModel); - assertThat(chatClientRequest.prompt().getOptions()).isEqualTo(chatOptions); - assertThat(chatClientRequest.prompt().getInstructions()).hasSize(3); - assertThat(chatClientRequest.prompt().getInstructions().get(0)).isEqualTo(messages.get(0)); - assertThat(chatClientRequest.prompt().getInstructions().get(1).getText()).isEqualTo("Instructions value"); - assertThat(chatClientRequest.prompt().getInstructions().get(2).getText()).isEqualTo("Question value"); - assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolNames()) - .containsAll(toolNames); - assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolCallbacks()) - .contains(toolCallback); - assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolContext()) - .containsAllEntriesOf(toolContext); - assertThat((List) chatClientRequest.context().get(ChatClientAttributes.ADVISORS.getKey())) - .containsAll(advisors); - assertThat(chatClientRequest.context()).containsAllEntriesOf(advisorContext); - - AdvisedRequest convertedAdvisedRequest = AdvisedRequest.from(chatClientRequest); - assertThat(convertedAdvisedRequest.toPrompt()).isEqualTo(chatClientRequest.prompt()); - assertThat(convertedAdvisedRequest.adviseContext()).containsAllEntriesOf(chatClientRequest.context()); - assertThat(chatClientRequest.context().get(ChatClientAttributes.USER_PARAMS.getKey())).isEqualTo(userParams); - assertThat(chatClientRequest.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey())) - .isEqualTo(systemParams); - } - - @Test - void whenConvertToAndFromChatClientRequestWithCustomTemplateRenderer() { - ChatModel chatModel = mock(ChatModel.class); - ChatOptions chatOptions = ToolCallingChatOptions.builder().build(); - SystemMessage systemMessage = new SystemMessage("Instructions "); - UserMessage userMessage = UserMessage.builder().text("Question ").media(mock(Media.class)).build(); - Map systemParams = Map.of("name", "Spring AI"); - Map userParams = Map.of("name", "Spring AI"); - - AdvisedRequest advisedRequest = AdvisedRequest.builder() - .chatModel(chatModel) - .chatOptions(chatOptions) - .systemText(systemMessage.getText()) - .systemParams(systemParams) - .userText(userMessage.getText()) - .userParams(userParams) - .media(userMessage.getMedia()) - .build(); - - TemplateRenderer customRenderer = StTemplateRenderer.builder() - .startDelimiterToken('<') - .endDelimiterToken('>') - .build(); - ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest(customRenderer); - - assertThat(chatClientRequest.prompt().getInstructions()).hasSize(2); - assertThat(chatClientRequest.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class); - assertThat(chatClientRequest.prompt().getInstructions().get(1)).isInstanceOf(UserMessage.class); - assertThat(chatClientRequest.context().get(ChatClientAttributes.USER_PARAMS.getKey())).isEqualTo(userParams); - assertThat(chatClientRequest.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey())) - .isEqualTo(systemParams); - } - - @Test - void whenUsingToPromptWithCustomTemplateRenderer() { - ChatModel chatModel = mock(ChatModel.class); - SystemMessage systemMessage = new SystemMessage("Instructions "); - UserMessage userMessage = UserMessage.builder().text("Question ").media(mock(Media.class)).build(); - Map systemParams = Map.of("name", "Spring AI"); - Map userParams = Map.of("name", "Spring AI"); - - AdvisedRequest advisedRequest = AdvisedRequest.builder() - .chatModel(chatModel) - .systemText(systemMessage.getText()) - .systemParams(systemParams) - .userText(userMessage.getText()) - .userParams(userParams) - .media(userMessage.getMedia()) - .build(); - - TemplateRenderer customRenderer = StTemplateRenderer.builder() - .startDelimiterToken('<') - .endDelimiterToken('>') - .build(); - var prompt = advisedRequest.toPrompt(customRenderer); - - assertThat(prompt.getInstructions()).hasSize(2); - assertThat(prompt.getInstructions().get(0).getText()).isEqualTo("Instructions Spring AI"); - assertThat(prompt.getInstructions().get(1).getText()).isEqualTo("Question Spring AI"); - } - -} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java deleted file mode 100644 index dd52eea8361..00000000000 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedResponseTests.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.advisor.api; - -import java.util.HashMap; -import java.util.Map; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.chat.client.ChatClientResponse; -import org.springframework.ai.chat.model.ChatResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - -/** - * Unit tests for {@link AdvisedResponse}. - * - * @author Thomas Vitale - */ -class AdvisedResponseTests { - - @Test - void buildAdvisedResponse() { - AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); - assertThat(advisedResponse).isNotNull(); - } - - @Test - void whenAdviseContextIsNullThenThrows() { - assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("adviseContext cannot be null"); - } - - @Test - void whenAdviseContextKeysIsNullThenThrows() { - Map adviseContext = new HashMap<>(); - adviseContext.put(null, "value"); - assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("adviseContext keys cannot be null"); - } - - @Test - void whenAdviseContextValuesIsNullThenThrows() { - Map adviseContext = new HashMap<>(); - adviseContext.put("key", null); - assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("adviseContext values cannot be null"); - } - - @Test - void whenBuildFromNullAdvisedResponseThenThrows() { - assertThatThrownBy(() -> AdvisedResponse.from((AdvisedResponse) null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("advisedResponse cannot be null"); - } - - @Test - void buildFromAdvisedResponse() { - AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); - AdvisedResponse.Builder builder = AdvisedResponse.from(advisedResponse); - assertThat(builder).isNotNull(); - } - - @Test - void whenUpdateFromNullContextThenThrows() { - AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of()); - assertThatThrownBy(() -> advisedResponse.updateContext(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("contextTransform cannot be null"); - } - - @Test - void whenConvertToAndFromChatClientResponse() { - ChatResponse chatResponse = mock(ChatResponse.class); - Map context = Map.of("key", "value"); - AdvisedResponse advisedResponse = new AdvisedResponse(chatResponse, context); - - ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); - - AdvisedResponse newAdvisedResponse = AdvisedResponse.from(chatClientResponse); - assertThat(newAdvisedResponse).isEqualTo(advisedResponse); - } - -} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java index f548346ba64..06c4f66368a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/AdvisorObservationContextTests.java @@ -18,12 +18,10 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.client.ChatClientRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; /** * Unit tests for {@link AdvisorObservationContext}. @@ -36,6 +34,7 @@ class AdvisorObservationContextTests { @Test void whenMandatoryOptionsThenReturn() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) .advisorName("AdvisorName") .build(); @@ -44,19 +43,17 @@ void whenMandatoryOptionsThenReturn() { @Test void missingAdvisorName() { - assertThatThrownBy(() -> AdvisorObservationContext.builder().build()) - .isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) + .build()).isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("advisorName cannot be null or empty"); } @Test - void whenBuilderWithAdvisedRequestThenReturn() { - AdvisorObservationContext observationContext = AdvisorObservationContext.builder() - .advisorName("AdvisorName") - .advisedRequest(mock(AdvisedRequest.class)) - .build(); - - assertThat(observationContext).isNotNull(); + void missingChatClientRequest() { + assertThatThrownBy(() -> AdvisorObservationContext.builder().advisorName("AdvisorName").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatClientRequest cannot be null"); } @Test @@ -69,13 +66,4 @@ void whenBuilderWithChatClientRequestThenReturn() { assertThat(observationContext).isNotNull(); } - @Test - void missingBuilderWithBothRequestsThenThrow() { - assertThatThrownBy(() -> AdvisorObservationContext.builder() - .advisedRequest(mock(AdvisedRequest.class)) - .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build()) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ChatClientRequest and AdvisedRequest cannot be set at the same time"); - } - } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java index c24098ad25a..98cc1d66b17 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/observation/DefaultAdvisorObservationConventionTests.java @@ -20,8 +20,10 @@ import io.micrometer.observation.Observation; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.observation.conventions.SpringAiKind; @@ -46,6 +48,7 @@ void shouldHaveName() { @Test void contextualName() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) .advisorName("MyName") .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("my_name"); @@ -54,6 +57,7 @@ void contextualName() { @Test void supportsAdvisorObservationContext() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) .advisorName("MyName") .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); @@ -63,6 +67,7 @@ void supportsAdvisorObservationContext() { @Test void shouldHaveLowCardinalityKeyValuesWhenDefined() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) .advisorName("MyName") .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( @@ -75,6 +80,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { @Test void shouldHaveKeyValuesWhenDefinedAndResponse() { AdvisorObservationContext observationContext = AdvisorObservationContext.builder() + .chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build()) .advisorName("MyName") .order(678) .build(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java deleted file mode 100644 index 70be4c4e620..00000000000 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/ChatClientInputContentObservationFilterTests.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.client.observation; - -import java.util.Map; - -import io.micrometer.common.KeyValue; -import io.micrometer.observation.Observation; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; - -import org.springframework.ai.chat.client.ChatClientAttributes; -import org.springframework.ai.chat.client.ChatClientRequest; -import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.prompt.Prompt; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Unit tests for {@link ChatClientInputContentObservationFilter}. - * - * @author Christian Tzolov - * @author Thomas Vitale - */ -@ExtendWith(MockitoExtension.class) -class ChatClientInputContentObservationFilterTests { - - private final ChatClientInputContentObservationFilter observationFilter = new ChatClientInputContentObservationFilter(); - - @Mock - ChatModel chatModel; - - @Test - void whenNotSupportedObservationContextThenReturnOriginalContext() { - var expectedContext = new Observation.Context(); - var actualContext = this.observationFilter.map(expectedContext); - - assertThat(actualContext).isEqualTo(expectedContext); - } - - @Test - void whenEmptyInputContentThenReturnOriginalContext() { - var request = ChatClientRequest.builder().prompt(new Prompt()).build(); - - var expectedContext = ChatClientObservationContext.builder().request(request).build(); - - var actualContext = this.observationFilter.map(expectedContext); - - assertThat(actualContext).isEqualTo(expectedContext); - } - - @Test - void whenWithTextThenAugmentContext() { - var request = ChatClientRequest.builder() - .prompt(new Prompt(new SystemMessage("sample system text"), new UserMessage("sample user text"))) - .context(ChatClientAttributes.USER_PARAMS.getKey(), Map.of("up1", "upv1")) - .context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), Map.of("sp1", "sp1v")) - .build(); - - var originalContext = ChatClientObservationContext.builder().request(request).build(); - - var augmentedContext = this.observationFilter.map(originalContext); - - assertThat(augmentedContext.getHighCardinalityKeyValues()) - .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT.asString(), "sample user text")); - assertThat(augmentedContext.getHighCardinalityKeyValues()) - .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS.asString(), "[\"up1\":\"upv1\"]")); - assertThat(augmentedContext.getHighCardinalityKeyValues()) - .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT.asString(), "sample system text")); - assertThat(augmentedContext.getHighCardinalityKeyValues()) - .contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM.asString(), "[\"sp1\":\"sp1v\"]")); - } - -} diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java index a9d52e82b53..313ec8295a2 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/observation/DefaultChatClientObservationConventionTests.java @@ -19,7 +19,6 @@ import java.util.List; import io.micrometer.common.KeyValue; -import io.micrometer.common.KeyValues; import io.micrometer.observation.Observation; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -27,13 +26,11 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.ai.chat.client.ChatClientAttributes; import org.springframework.ai.chat.client.ChatClientRequest; +import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.client.advisor.api.CallAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.model.ChatModel; @@ -62,8 +59,8 @@ class DefaultChatClientObservationConventionTests { ChatClientRequest request; - static CallAroundAdvisor dummyAdvisor(String name) { - return new CallAroundAdvisor() { + static CallAdvisor dummyAdvisor(String name) { + return new CallAdvisor() { @Override public String getName() { @@ -76,7 +73,8 @@ public int getOrder() { } @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { + public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, + CallAdvisorChain callAdvisorChain) { return null; } @@ -156,7 +154,7 @@ void shouldHaveOptionalKeyValues() { ChatClientObservationContext observationContext = ChatClientObservationContext.builder() .request(request) - .withFormat("json") + .format("json") .advisors(List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2"))) .stream(true) .build(); @@ -166,33 +164,7 @@ void shouldHaveOptionalKeyValues() { ["advisor1", "advisor2"]"""), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_CONVERSATION_ID.asString(), "007"), KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_NAMES.asString(), """ - ["tool1", "tool2", "toolCallback1", "toolCallback2"]"""), - KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), """ - ["chat_memory_conversation_id":"007"]"""), - KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), """ - ["tool1", "tool2"]"""), - KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(), """ - ["toolCallback1", "toolCallback2"]""")); - } - - @Test - void entriesInAdvisorContextAreNotRemoved() { - var request = ChatClientRequest.builder() - .prompt(new Prompt("")) - .context("advParam1", "advisorParam1Value") - .context(ChatClientAttributes.ADVISORS.getKey(), - List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2"))) - .build(); - - ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .request(request) - .build(); - - assertThat(observationContext.getRequest().context()).hasSize(2); - - this.observationConvention.chatClientAdvisorParams(KeyValues.empty(), observationContext); - - assertThat(observationContext.getRequest().context()).hasSize(2); + ["tool1", "tool2", "toolCallback1", "toolCallback2"]""")); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/ChatTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/ChatTests.java deleted file mode 100644 index 711889d2ce1..00000000000 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/prompt/ChatTests.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.prompt; - -public class ChatTests { - - // @Test - // void testChat() { - // - // String customerStyle = "American English in a calm and respectful tone"; - // String customerEmail = "Arrr, I be fuming that me blender lid " - // + "flew off and splattered me kitchen walls " - // + "with smoothie! And to make matters worse, " - // + "the warranty don't cover the cost of " - // + "cleaning up me kitchen. I need yer help " - // + "right now, matey!"; - // ChatOpenAi chatOpenAi = new ChatOpenAi(); - // chatOpenAi - // - // } - -} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc index 78649cc23ba..6db86f730ac 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/upgrade-notes.adoc @@ -29,6 +29,13 @@ For details, refer to: [[upgrading-to-1-0-0-RC1]] == Upgrading to 1.0.0-RC1 +=== Chat Client And Advisors + +* When building a `Prompt` from the ChatClient input, the `SystemMessage` built from `systemText()` is now placed first in the message list. Before, it was put last, resulting in errors with several model providers. +* In `AbstractChatMemoryAdvisor`, the `doNextWithProtectFromBlockingBefore()` protected method has been changed from accepting the old `AdvisedRequest` to the new `ChatClientRequest`. It’s a breaking change since the alternative was not part of M8. +* `MessageAggregator` has a new method to aggregate messages from `ChatClientRequest`. The previous method aggregating messages from the old `AdvisedRequest` has been removed, since it was already marked as deprecated in M8. +* In `SimpleLoggerAdvisor`, the `requestToString` input argument needs to be updated to use `ChatClientRequest`. It’s a breaking change since the alternative was not part of M8 yet. Same thing about the constructor. + === Breaking Changes The Watson AI model was removed as it was based on the older text generation that is considered outdated as there is a new chat generation model available. Hopefully Watson will reappear in a future version of Spring AI diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 43a043d467a..0cf5134354a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -100,6 +100,20 @@ public List getInstructions() { return this.messages; } + /** + * Get the first system message in the prompt. If no system message is found, an empty + * SystemMessage is returned. + */ + public SystemMessage getSystemMessage() { + for (int i = 0; i <= this.messages.size() - 1; i++) { + Message message = this.messages.get(i); + if (message instanceof SystemMessage systemMessage) { + return systemMessage; + } + } + return new SystemMessage(""); + } + /** * Get the last user message in the prompt. If no user message is found, an empty * UserMessage is returned. @@ -165,11 +179,44 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) { } /** - * @param userMessageAugmenter the function to augment the last user message. - * @return a new prompt instance with the augmented user message. + * Augments the first system message in the prompt with the provided function. If no + * system message is found, a new one is created with the provided text. + * @return a new {@link Prompt} instance with the augmented system message. */ - public Prompt augmentUserMessage(Function userMessageAugmenter) { + public Prompt augmentSystemMessage(Function systemMessageAugmenter) { + + var messagesCopy = new ArrayList<>(this.messages); + for (int i = 0; i <= this.messages.size() - 1; i++) { + Message message = messagesCopy.get(i); + if (message instanceof SystemMessage systemMessage) { + messagesCopy.set(i, systemMessageAugmenter.apply(systemMessage)); + break; + } + if (i == 0) { + // If no system message is found, create a new one with the provided text + // and add it as the first item in the list. + messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage(""))); + } + } + return new Prompt(messagesCopy, null == this.chatOptions ? null : this.chatOptions.copy()); + } + + /** + * Augments the last system message in the prompt with the provided text. If no system + * message is found, a new one is created with the provided text. + * @return a new {@link Prompt} instance with the augmented system message. + */ + public Prompt augmentSystemMessage(String newSystemText) { + return augmentSystemMessage(systemMessage -> systemMessage.mutate().text(newSystemText).build()); + } + + /** + * Augments the last user message in the prompt with the provided function. If no user + * message is found, a new one is created with the provided text. + * @return a new {@link Prompt} instance with the augmented user message. + */ + public Prompt augmentUserMessage(Function userMessageAugmenter) { var messagesCopy = new ArrayList<>(this.messages); for (int i = messagesCopy.size() - 1; i >= 0; i--) { Message message = messagesCopy.get(i); @@ -186,11 +233,9 @@ public Prompt augmentUserMessage(Function userMessageA } /** - * Creates a copy of the prompt, replacing the text content of the last UserMessage - * with the provided text. If no UserMessage exists, a new one with the given text is - * added. - * @param newUserText The new text content for the last user message. - * @return A new Prompt instance with the augmented user message text. + * Augments the last user message in the prompt with the provided text. If no user + * message is found, a new one is created with the provided text. + * @return a new {@link Prompt} instance with the augmented user message. */ public Prompt augmentUserMessage(String newUserText) { return augmentUserMessage(userMessage -> userMessage.mutate().text(newUserText).build()); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java index 5668fc16193..a9e8f84ecef 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java @@ -152,10 +152,90 @@ void augmentUserMessageWhenNone() { Prompt copy = prompt.augmentUserMessage(message -> message.mutate().text("How are you?").build()); + assertThat(copy.getInstructions().get(copy.getInstructions().size() - 1)).isInstanceOf(UserMessage.class); assertThat(copy.getUserMessage()).isNotNull(); assertThat(copy.getUserMessage().getText()).isEqualTo("How are you?"); assertThat(prompt.getUserMessage()).isNotNull(); assertThat(prompt.getUserMessage().getText()).isEqualTo(""); } + @Test + void getSystemMessageWhenSingle() { + Prompt prompt = Prompt.builder().messages(new SystemMessage("Hello")).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + } + + @Test + void getSystemMessageWhenMultiple() { + Prompt prompt = Prompt.builder() + .messages(new SystemMessage("Hello"), new SystemMessage("How are you?")) + .build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + } + + @Test + void getSystemMessageWhenNone() { + Prompt prompt = Prompt.builder().messages(new UserMessage("You'll be back!")).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo(""); + + prompt = Prompt.builder().messages(List.of()).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo(""); + } + + @Test + void augmentSystemMessageWhenSingle() { + Prompt prompt = Prompt.builder().messages(new SystemMessage("Hello")).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + + Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build()); + + assertThat(copy.getSystemMessage()).isNotNull(); + assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?"); + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + } + + @Test + void augmentSystemMessageWhenMultiple() { + Prompt prompt = Prompt.builder() + .messages(new SystemMessage("Hello"), new SystemMessage("How are you?")) + .build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + + Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("What about you?").build()); + + assertThat(copy.getSystemMessage()).isNotNull(); + assertThat(copy.getSystemMessage().getText()).isEqualTo("What about you?"); + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + } + + @Test + void augmentSystemMessageWhenNone() { + Prompt prompt = Prompt.builder().messages(new UserMessage("You'll be back!")).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo(""); + + Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build()); + + assertThat(copy.getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(copy.getSystemMessage()).isNotNull(); + assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?"); + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getSystemMessage().getText()).isEqualTo(""); + } + } diff --git a/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java index ac61071831e..69521638d88 100644 --- a/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-rag/src/main/java/org/springframework/ai/rag/advisor/RetrievalAugmentationAdvisor.java @@ -25,8 +25,6 @@ import reactor.core.scheduler.Scheduler; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.BaseAdvisor; import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; @@ -100,16 +98,6 @@ public static Builder builder() { return new Builder(); } - /** - * @deprecated in favour of {@link #before(ChatClientRequest, AdvisorChain)} - */ - @Override - @Deprecated - public AdvisedRequest before(AdvisedRequest advisedRequest) { - ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest(); - return AdvisedRequest.from(before(chatClientRequest, null)); - } - @Override public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable AdvisorChain advisorChain) { Map context = new HashMap<>(chatClientRequest.context()); @@ -163,16 +151,6 @@ private Map.Entry> getDocumentsForQuery(Query query) { return Map.entry(query, documents); } - /** - * @deprecated in favour of {@link #after(ChatClientResponse, AdvisorChain)} - */ - @Override - @Deprecated - public AdvisedResponse after(AdvisedResponse advisedResponse) { - ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse(); - return AdvisedResponse.from(after(chatClientResponse, null)); - } - @Override public ChatClientResponse after(ChatClientResponse chatClientResponse, @Nullable AdvisorChain advisorChain) { ChatResponse.Builder chatResponseBuilder; diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java index a41c6a6ef1c..fd616b129a5 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import java.util.UUID; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.DisplayName; @@ -26,6 +27,7 @@ import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.postgresql.ds.PGSimpleDataSource; +import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor; import org.testcontainers.containers.PostgreSQLContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -75,22 +77,20 @@ class PgVectorStoreWithChatMemoryAdvisorIT { return chatModel; } - private static void initStore(PgVectorStore store) throws Exception { + private static void initStore(PgVectorStore store, String conversationId) { store.afterPropertiesSet(); // fill the store - store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")), - new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER")))); + store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", conversationId)), + new Document("Tell me a bad joke", Map.of("conversationId", conversationId, "messageType", "USER")))); } private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception { JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer(); - PgVectorStore vectorStore = PgVectorStore.builder(jdbcTemplate, embeddingModel) + return PgVectorStore.builder(jdbcTemplate, embeddingModel) .dimensions(3) // match // embeddings .initializeSchema(true) .build(); - initStore(vectorStore); - return vectorStore; } private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() { @@ -105,7 +105,7 @@ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatM ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); verify(chatModel).call(promptCaptor.capture()); assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class); - assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualTo(""" + assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualToIgnoringWhitespace(""" Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. @@ -129,19 +129,59 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { // faked embedding model EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); + String conversationId = UUID.randomUUID().toString(); + initStore(store, conversationId); // do the chat ChatClient.builder(chatModel) .build() .prompt() .user("joke") - .advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId)) .call() .chatResponse(); verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel); } + @Test + void advisedChatShouldHaveSimilarMessagesFromVectorStoreWhenSystemMessageProvided() throws Exception { + // faked ChatModel + ChatModel chatModel = chatModelAlwaysReturnsTheSameReply(); + // faked embedding model + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); + String conversationId = UUID.randomUUID().toString(); + initStore(store, conversationId); + + // do the chat + ChatClient.builder(chatModel) + .build() + .prompt() + .system("You are a helpful assistant.") + .user("joke") + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId)) + .call() + .chatResponse(); + + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + verify(chatModel).call(promptCaptor.capture()); + assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualToIgnoringWhitespace(""" + You are a helpful assistant. + + Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. + + --------------------- + LONG_TERM_MEMORY: + Tell me a good joke + Tell me a bad joke + --------------------- + """); + } + @SuppressWarnings("unchecked") private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() { EmbeddingModel embeddingModel = mock(EmbeddingModel.class);