Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f2bcc5b
Refactor advisor architecture in Spring AI
tzolov Sep 5, 2024
9d4c146
Refactored advisors even further
chemicL Sep 27, 2024
46d8dd9
Minor code style improvements
tzolov Sep 27, 2024
add5c9f
Add missing javadocs
tzolov Sep 27, 2024
087551d
Improve advisor chain management in DefaultChatClient and DefaultArou…
tzolov Sep 27, 2024
48e879f
Fix the protectFromBlocking model
tzolov Sep 27, 2024
8774c8b
Generalize the Protect From Blocking functionality accross all advisors
tzolov Sep 28, 2024
269ffa3
Add Advisors documentation
tzolov Sep 28, 2024
73795ae
improve advisors doc
tzolov Sep 28, 2024
62eb7eb
Another round of adviser doc improvments
tzolov Sep 28, 2024
71f6462
additional advisor doc improvements
tzolov Sep 28, 2024
58d1f30
Improve advisor ordering
tzolov Sep 29, 2024
b447ad3
strealine docs
tzolov Sep 29, 2024
3cfa79b
more doc streamlining
tzolov Sep 29, 2024
a290b91
Update class diagram
tzolov Sep 29, 2024
8f0b0c9
update class diagram
tzolov Sep 29, 2024
c583bf4
Update chatclinet docs to reflect advisors changes
tzolov Sep 29, 2024
4a08866
Add corss refference
tzolov Sep 29, 2024
9010da8
Add order field to AdvisorObservationContext
tzolov Sep 29, 2024
3b27d5e
Removed StreamAggregationAdvisor and related comments for now
chemicL Sep 30, 2024
eb9d3b0
Removed missed commented code
chemicL Sep 30, 2024
1e2bd57
Move advisors docs under the chatclient section
tzolov Sep 29, 2024
63f6a11
docs improvments
tzolov Sep 30, 2024
c659223
udpate the chatclient advisors samples
tzolov Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.api.RequestAdvisor;
import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.OpenAiChatModel;
Expand Down Expand Up @@ -65,7 +65,7 @@ public class OpenAiPaymentTransactionIT {
record TransactionStatusResponse(String id, String status) {
}

private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor {
private static class LoggingAdvisor implements CallAroundAdvisor {

private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);

Expand All @@ -74,7 +74,23 @@ public String getName() {
}

@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
public int getOrder() {
return 0;
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {

advisedRequest = this.before(advisedRequest);

AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);

this.observeAfter(advisedResponse);

return advisedResponse;
}

private AdvisedRequest before(AdvisedRequest request) {
logger.info("System text: \n" + request.systemText());
logger.info("System params: " + request.systemParams());
logger.info("User text: \n" + request.userText());
Expand All @@ -86,10 +102,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
return request;
}

@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
logger.info("Response: " + response);
return response;
private void observeAfter(AdvisedResponse advisedResponse) {
logger.info("Response: " + advisedResponse.response());
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down Expand Up @@ -64,6 +65,38 @@ class OpenAiChatClientIT extends AbstractIT {
record ActorsFilms(String actor, List<String> movies) {
}

@Test
@Disabled("Although the Re2 advisor improves the response correctness it is not always guarantied to work.")
void re2() {
// .user(" Could Scooby Doo fit in a Kangaroo Pouch? Choices: (A) Yes (B) No")
// .user("Roger has 5 tennis balls. He buys 2 more cans of tennis " +
// "balls. Each can has 3 tennis balls. How many tennis balls " +
// "does he have now?")

String REASON_QUESTION = """
What do these words have in common?
Freight Stone Often Canine.
""";

// @formatter:off
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultOptions(OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_O.getValue()).build())
.defaultUser(REASON_QUESTION)
.build();

String response = chatClient.prompt()
.advisors(new ReReadingAdvisor())
.call()
.content();
// @formatter:on

logger.info("" + response);
assertThat(response.toLowerCase().replace("(", " ").replace(")", " ").replace("\"", " ").replace("\"", " "))
.contains(" eight", " one", " ten", " nine");

}

@Test
void call() {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.chat.client;

import java.util.HashMap;
import java.util.Map;

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;

import reactor.core.publisher.Flux;

/**
* Drawing inspiration from the human strategy of re-reading, this advisor implements a
* re-reading strategy for LLM reasoning, dubbed RE2, to enhance understanding in the
* input phase. Based on the article:
* <a href="https://arxiv.org/pdf/2309.06275">Re-Reading Improves Reasoning in Large
* Language Models</a>
*
* @author Christian Tzolov
* @since 1.0.0
*/
public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private static final String DEFAULT_RE2_ADVISE_TEMPLATE = """
{re2_input_query}
Read the question again: {re2_input_query}
""";

private final String re2AdviseTemplate;

private int order = 0;

public ReReadingAdvisor() {
this(DEFAULT_RE2_ADVISE_TEMPLATE);
}

public ReReadingAdvisor(String re2AdviseTemplate) {
this.re2AdviseTemplate = re2AdviseTemplate;
}

public String getName() {
return this.getClass().getSimpleName();
}

private AdvisedRequest before(AdvisedRequest advisedRequest) {

Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
advisedUserParams.put("re2_input_query", advisedRequest.userText());

return AdvisedRequest.from(advisedRequest)
.withUserText(this.re2AdviseTemplate)
.withUserParams(advisedUserParams)
.build();
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
return chain.nextAroundCall(this.before(advisedRequest));
}

@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
return chain.nextAroundStream(this.before(advisedRequest));
}

@Override
public int getOrder() {
return this.order;
}

public ReReadingAdvisor withOrder(int order) {
this.order = order;
return this;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.api.RequestAdvisor;
import org.springframework.ai.chat.client.advisor.api.ResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
Expand Down Expand Up @@ -65,7 +65,7 @@ public class VertexAiGeminiPaymentTransactionIT {
record TransactionStatusResponse(String id, String status) {
}

private static class LoggingAdvisor implements RequestAdvisor, ResponseAdvisor {
private static class LoggingAdvisor implements CallAroundAdvisor {

private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);

Expand All @@ -75,7 +75,18 @@ public String getName() {
}

@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
public int getOrder() {
return 0;
}

@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
var response = chain.nextAroundCall(before(advisedRequest));
observeAfter(response);
return response;
}

private AdvisedRequest before(AdvisedRequest request) {
logger.info("System text: \n" + request.systemText());
logger.info("System params: " + request.systemParams());
logger.info("User text: \n" + request.userText());
Expand All @@ -87,10 +98,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
return request;
}

@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
logger.info("Response: " + response);
return response;
private void observeAfter(AdvisedResponse advisedResponse) {
logger.info("Response: " + advisedResponse.response());
}

}
Expand Down
Loading