diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java index bd8ee2c51d7..02b26768317 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientAttributes.java @@ -21,10 +21,7 @@ * * @author Thomas Vitale * @since 1.0.0 - * @deprecated only introduced to smooth the transition to the new APIs and ensure - * backward compatibility */ -@Deprecated public enum ChatClientAttributes { //@formatter:off @@ -33,7 +30,6 @@ public enum ChatClientAttributes { ADVISORS("spring.ai.chat.client.advisors"), @Deprecated // Only for backward compatibility until the next release. CHAT_MODEL("spring.ai.chat.client.model"), - @Deprecated // Only for backward compatibility until the next release. OUTPUT_FORMAT("spring.ai.chat.client.output.format"), @Deprecated // Only for backward compatibility until the next release. USER_PARAMS("spring.ai.chat.client.user.params"), diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 1ecb7cdcad8..c58bd362088 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -496,11 +496,13 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest, @Nullable String outputFormat) { - ChatClientRequest formattedChatClientRequest = StringUtils.hasText(outputFormat) - ? augmentPromptWithFormatInstructions(chatClientRequest, outputFormat) : chatClientRequest; + + if (outputFormat != null) { + chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), outputFormat); + } ChatClientObservationContext observationContext = ChatClientObservationContext.builder() - .request(formattedChatClientRequest) + .request(chatClientRequest) .advisors(advisorChain.getCallAdvisors()) .stream(false) .withFormat(outputFormat) @@ -510,7 +512,7 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry); var chatClientResponse = observation.observe(() -> { // Apply the advisor chain that terminates with the ChatModelCallAdvisor. - return advisorChain.nextCall(formattedChatClientRequest); + return advisorChain.nextCall(chatClientRequest); }); return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build(); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java index 68ccd8cb8f4..d2dd9cf4d62 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java @@ -16,14 +16,17 @@ package org.springframework.ai.chat.client.advisor; +import org.springframework.ai.chat.client.ChatClientAttributes; import org.springframework.ai.chat.client.ChatClientRequest; import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.CallAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.core.Ordered; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import java.util.Map; @@ -46,9 +49,29 @@ private ChatModelCallAdvisor(ChatModel chatModel) { public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); - ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt()); + ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest); + + ChatResponse chatResponse = chatModel.call(formattedChatClientRequest.prompt()); return ChatClientResponse.builder() .chatResponse(chatResponse) + .context(Map.copyOf(formattedChatClientRequest.context())) + .build(); + } + + private static ChatClientRequest augmentWithFormatInstructions(ChatClientRequest chatClientRequest) { + String outputFormat = (String) chatClientRequest.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()); + + if (!StringUtils.hasText(outputFormat)) { + return chatClientRequest; + } + + Prompt augmentedPrompt = chatClientRequest.prompt() + .augmentUserMessage(userMessage -> userMessage.mutate() + .text(userMessage.getText() + System.lineSeparator() + outputFormat) + .build()); + + return ChatClientRequest.builder() + .prompt(augmentedPrompt) .context(Map.copyOf(chatClientRequest.context())) .build(); } diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorIT.java index 05893ebd10b..cee8e158cd2 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorIT.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/QuestionAnswerAdvisorIT.java @@ -163,6 +163,28 @@ void qaCustomPromptTemplate() { evaluateRelevancy(question, chatResponse); } + @Test + void qaOutputConverter() { + String question = "Where does the adventure of Anacletus and Birba take place?"; + + QuestionAnswerAdvisor qaAdvisor = QuestionAnswerAdvisor.builder(this.pgVectorStore).build(); + + Answer answer = ChatClient.builder(this.openAiChatModel) + .build() + .prompt(question) + .advisors(qaAdvisor) + .call() + .entity(Answer.class); + + assertThat(answer).isNotNull(); + + System.out.println(answer); + assertThat(answer.content()).containsIgnoringCase("Highlands"); + } + + private record Answer(String content) { + } + private void evaluateRelevancy(String question, ChatResponse chatResponse) { EvaluationRequest evaluationRequest = new EvaluationRequest(question, chatResponse.getMetadata().get(QuestionAnswerAdvisor.RETRIEVED_DOCUMENTS),