Skip to content

Commit 5d9a47f

Browse files
committed
Generalize the Protect From Blocking functionality across all advisors
1 parent 7610897 commit 5d9a47f

File tree

5 files changed

+130
-30
lines changed

5 files changed

+130
-30
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,20 @@
1717
package org.springframework.ai.chat.client.advisor;
1818

1919
import java.util.Map;
20+
import java.util.function.Function;
2021

22+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
23+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
2124
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
2225
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
26+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
2327
import org.springframework.core.Ordered;
2428
import org.springframework.util.Assert;
2529

30+
import reactor.core.publisher.Flux;
31+
import reactor.core.publisher.Mono;
32+
import reactor.core.scheduler.Schedulers;
33+
2634
/**
2735
* Abstract class that serves as a base for chat memory advisors.
2836
*
@@ -46,11 +54,14 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
4654

4755
protected final int defaultChatMemoryRetrieveSize;
4856

57+
private final boolean protectFromBlocking;
58+
4959
public AbstractChatMemoryAdvisor(T chatMemory) {
50-
this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE);
60+
this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true);
5161
}
5262

53-
public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize) {
63+
public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize,
64+
boolean protectFromBlocking) {
5465

5566
Assert.notNull(chatMemory, "The chatMemory must not be null!");
5667
Assert.hasText(defaultConversationId, "The conversationId must not be empty!");
@@ -59,6 +70,7 @@ public AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int
5970
this.chatMemoryStore = chatMemory;
6071
this.defaultConversationId = defaultConversationId;
6172
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
73+
this.protectFromBlocking = protectFromBlocking;
6274
}
6375

6476
@Override
@@ -90,4 +102,20 @@ protected int doGetChatMemoryRetrieveSize(Map<String, Object> context) {
90102
: this.defaultChatMemoryRetrieveSize;
91103
}
92104

105+
/**
 * Applies the given "before" transformation to the request and then hands the
 * transformed request to the next advisor in the streaming chain.
 * <p>
 * When {@code protectFromBlocking} is {@code true}, the request is wrapped in a
 * {@link Mono}, published on {@link Schedulers#boundedElastic()}, and
 * {@code beforeAdvise} is applied there before the chain is invoked, so the
 * transformation does not run on the calling thread. When it is {@code false},
 * {@code beforeAdvise} is applied immediately on the calling thread.
 * @param advisedRequest the incoming request to transform
 * @param chain the streaming advisor chain to continue with
 * @param beforeAdvise the transformation applied to the request before the chain is
 * invoked
 * @return the stream of responses produced by the rest of the chain
 */
protected Flux<AdvisedResponse> doNextWithProtectFromBlockingBefore(AdvisedRequest advisedRequest,
106+
StreamAroundAdvisorChain chain, Function<AdvisedRequest, AdvisedRequest> beforeAdvise) {
107+
108+
// This can be executed by both blocking and non-blocking Threads
109+
// E.g. a command line or Tomcat blocking Thread implementation
110+
// or by a WebFlux dispatch in a non-blocking manner.
111+
return (this.protectFromBlocking) ?
112+
// @formatter:off
113+
Mono.just(advisedRequest)
114+
.publishOn(Schedulers.boundedElastic())
115+
.map(beforeAdvise)
116+
.flatMapMany(request -> chain.nextAroundStream(request))
117+
: chain.nextAroundStream(beforeAdvise.apply(advisedRequest));
// @formatter:on
118+
}
119+
120+
93121
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory) {
4343
}
4444

4545
public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize) {
46-
super(chatMemory, defaultConversationId, chatHistoryWindowSize);
46+
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true);
4747
}
4848

4949
@Override
@@ -61,9 +61,8 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis
6161
@Override
6262
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
6363

64-
advisedRequest = this.before(advisedRequest);
65-
66-
Flux<AdvisedResponse> advisedResponses = chain.nextAroundStream(advisedRequest);
64+
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
65+
this::before);
6766

6867
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
6968
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) {
6666

6767
public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
6868
String systemTextAdvise) {
69-
super(chatMemory, defaultConversationId, chatHistoryWindowSize);
69+
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true);
7070
this.systemTextAdvise = systemTextAdvise;
7171
}
7272

@@ -85,9 +85,8 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis
8585
@Override
8686
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
8787

88-
advisedRequest = this.before(advisedRequest);
89-
90-
Flux<AdvisedResponse> advisedResponses = chain.nextAroundStream(advisedRequest);
88+
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
89+
this::before);
9190

9291
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
9392
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.springframework.util.StringUtils;
4040

4141
import reactor.core.publisher.Flux;
42+
import reactor.core.publisher.Mono;
43+
import reactor.core.scheduler.Schedulers;
4244

4345
/**
4446
* Context for the question is retrieved from a Vector Store and added to the prompt's
@@ -69,10 +71,24 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
6971

7072
public static final String FILTER_EXPRESSION = "qa_filter_expression";
7173

74+
private final boolean protectFromBlocking;
75+
76+
/**
77+
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
78+
* combines it with the user's text.
79+
* @param vectorStore The vector store to use
80+
*/
7281
public QuestionAnswerAdvisor(VectorStore vectorStore) {
7382
this(vectorStore, SearchRequest.defaults(), DEFAULT_USER_TEXT_ADVISE);
7483
}
7584

85+
/**
86+
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
87+
* combines it with the user's text.
88+
* @param vectorStore The vector store to use
89+
* @param searchRequest The search request defined using the portable filter
90+
* expression syntax
91+
*/
7692
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
7793
this(vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
7894
}
@@ -85,9 +101,26 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
85101
* expression syntax
86102
* @param userTextAdvise the user text to append to the existing user prompt. The text
87103
* should contain a placeholder named "question_answer_context".
88-
*
89104
*/
90105
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
106+
this(vectorStore, searchRequest, userTextAdvise, true);
107+
}
108+
109+
/**
110+
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
111+
* combines it with the user's text.
112+
* @param vectorStore The vector store to use
113+
* @param searchRequest The search request defined using the portable filter
114+
* expression syntax
115+
* @param userTextAdvise the user text to append to the existing user prompt. The text
116+
* should contain a placeholder named "question_answer_context".
117+
* @param protectFromBlocking if true, streaming calls run the advisor's "before" step
118+
* on the Reactor bounded-elastic scheduler instead of the calling thread, which is
119+
* useful when the advisor is invoked from a non-blocking thread. If false, the
120+
* "before" step runs directly on the calling thread. It is true by default.
121+
*/
122+
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
123+
boolean protectFromBlocking) {
91124

92125
Assert.notNull(vectorStore, "The vectorStore must not be null!");
93126
Assert.notNull(searchRequest, "The searchRequest must not be null!");
@@ -96,6 +129,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
96129
this.vectorStore = vectorStore;
97130
this.searchRequest = searchRequest;
98131
this.userTextAdvise = userTextAdvise;
132+
this.protectFromBlocking = protectFromBlocking;
99133
}
100134

101135
@Override
@@ -121,9 +155,19 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis
121155
@Override
122156
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
123157

124-
advisedRequest = before(advisedRequest);
125-
126-
return chain.nextAroundStream(advisedRequest).map(ar -> {
158+
// This can be executed by both blocking and non-blocking Threads
159+
// E.g. a command line or Tomcat blocking Thread implementation
160+
// or by a WebFlux dispatch in a non-blocking manner.
161+
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
162+
// @formatter:off
163+
Mono.just(advisedRequest)
164+
.publishOn(Schedulers.boundedElastic())
165+
.map(this::before)
166+
.flatMapMany(request -> chain.nextAroundStream(request))
167+
: chain.nextAroundStream(before(advisedRequest));
168+
// @formatter:on
169+
170+
return advisedResponses.map(ar -> {
127171
if (onFinishReason().test(ar)) {
128172
ar = after(ar);
129173
}
@@ -191,4 +235,47 @@ private Predicate<AdvisedResponse> onFinishReason() {
191235
.isPresent();
192236
}
193237

238+
/**
 * Creates a {@link Builder} for a {@code QuestionAnswerAdvisor} backed by the given
 * vector store.
 * @param vectorStore the vector store to use; must not be null
 * @return a new builder instance
 */
public static Builder builder(VectorStore vectorStore) {
239+
return new Builder(vectorStore);
240+
}
241+
242+
/**
 * Fluent builder for {@code QuestionAnswerAdvisor}. Unless overridden, the search
 * request is {@code SearchRequest.defaults()}, the user text advise is
 * {@code DEFAULT_USER_TEXT_ADVISE}, and {@code protectFromBlocking} is {@code true}.
 */
public static class Builder {
243+
244+
private final VectorStore vectorStore;
245+
246+
private SearchRequest searchRequest = SearchRequest.defaults();
247+
248+
private String userTextAdvise = DEFAULT_USER_TEXT_ADVISE;
249+
250+
private boolean protectFromBlocking = true;
251+
252+
// Private: instances are obtained through QuestionAnswerAdvisor.builder(VectorStore).
private Builder(VectorStore vectorStore) {
253+
Assert.notNull(vectorStore, "The vectorStore must not be null!");
254+
this.vectorStore = vectorStore;
255+
}
256+
257+
/**
 * Sets the search request passed to the advisor.
 * @param searchRequest the search request; must not be null
 * @return this builder
 */
public Builder withSearchRequest(SearchRequest searchRequest) {
258+
Assert.notNull(searchRequest, "The searchRequest must not be null!");
259+
this.searchRequest = searchRequest;
260+
return this;
261+
}
262+
263+
/**
 * Sets the user text advise passed to the advisor.
 * @param userTextAdvise the user text advise; must contain text
 * @return this builder
 */
public Builder withUserTextAdvise(String userTextAdvise) {
264+
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
265+
this.userTextAdvise = userTextAdvise;
266+
return this;
267+
}
268+
269+
/**
 * Sets the {@code protectFromBlocking} flag passed to the advisor.
 * @param protectFromBlocking the flag value to use
 * @return this builder
 */
public Builder withProtectFromBlocking(boolean protectFromBlocking) {
270+
this.protectFromBlocking = protectFromBlocking;
271+
return this;
272+
}
273+
274+
/**
 * Creates the advisor from the values collected by this builder.
 * @return a new {@code QuestionAnswerAdvisor}
 */
public QuestionAnswerAdvisor build() {
275+
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise,
276+
this.protectFromBlocking);
277+
}
278+
279+
}
280+
194281
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
import org.springframework.ai.vectorstore.VectorStore;
3737

3838
import reactor.core.publisher.Flux;
39-
import reactor.core.publisher.Mono;
40-
import reactor.core.scheduler.Schedulers;
4139

4240
/**
4341
* Memory is retrieved from a VectorStore added into the prompt's system text.
@@ -64,8 +62,6 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
6462

6563
private final String systemTextAdvise;
6664

67-
private final boolean protectFromBlocking = true;
68-
6965
public VectorStoreChatMemoryAdvisor(VectorStore vectorStore) {
7066
this(vectorStore, DEFAULT_SYSTEM_TEXT_ADVISE);
7167
}
@@ -82,7 +78,7 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve
8278

8379
public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
8480
int chatHistoryWindowSize, String systemTextAdvise) {
85-
super(vectorStore, defaultConversationId, chatHistoryWindowSize);
81+
super(vectorStore, defaultConversationId, chatHistoryWindowSize, true);
8682
this.systemTextAdvise = systemTextAdvise;
8783
}
8884

@@ -101,17 +97,8 @@ public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvis
10197
@Override
10298
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
10399

104-
// This can be executed by both blocking and non-blocking Threads
105-
// E.g. a command line or Tomcat blocking Thread implementation
106-
// or by a WebFlux dispatch in a non-blocking manner.
107-
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
108-
// @formatter:off
109-
Mono.just(advisedRequest)
110-
.publishOn(Schedulers.boundedElastic())
111-
.map(this::before)
112-
.flatMapMany(request -> chain.nextAroundStream(request))
113-
: chain.nextAroundStream(this.before(advisedRequest));
114-
// @formatter:on
100+
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
101+
this::before);
115102

116103
// The observeAfter will certainly be executed on non-blocking Threads in case
117104
// of some models - e.g. when the model client is a WebClient

0 commit comments

Comments
 (0)