Skip to content

Commit 2198e5a

Browse files
committed
Improve advisor chain management in DefaultChatClient and DefaultAroundAdvisorChain
- Refactor DefaultChatClient to use a builder pattern for advisor chain construction - Update DefaultAroundAdvisorChain to separate call and stream advisors - Improve advisor ordering and management in DefaultAroundAdvisorChain.Builder - Enhance AdvisorsTests to verify correct advisor execution order - Remove redundant reordering logic from DefaultAroundAdvisorChain
1 parent cbe711f commit 2198e5a

File tree

3 files changed

+158
-130
lines changed

3 files changed

+158
-130
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest
379379

380380
// Apply the around advisor chain that terminates with the, last, model call
381381
// advisor.
382-
AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChain.nextAroundCall(advisedRequest);
382+
AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build()
383+
.nextAroundCall(advisedRequest);
383384

384385
return advisedResponse.response();
385386
}
@@ -424,7 +425,7 @@ private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequ
424425
// advisedResponse -> couchbaseClient.streamToBucket(advisedResponse).then();
425426
//
426427
// List<StreamAggregationAdvisor> streamAggregationAdvisors = List.of(aggregationAdvisor);
427-
Flux<AdvisedResponse> stream = inputRequest.aroundAdvisorChain.nextAroundStream(initialAdvisedRequest);
428+
Flux<AdvisedResponse> stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest);
428429

429430
// if (aggregationAdvisor != null) {
430431
// stream = new MessageAggregator().aggregateAdvisedResponse(stream, aggregationAdvisor);
@@ -485,11 +486,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
485486

486487
private final Map<String, Object> advisorParams = new HashMap<>();
487488

488-
private final DefaultAroundAdvisorChain aroundAdvisorChain;
489-
490-
public CallAroundAdvisorChain getAroundAdvisorChain() {
491-
return this.aroundAdvisorChain;
492-
}
489+
private final DefaultAroundAdvisorChain.Builder aroundAdvisorChainBuilder;
493490

494491
private ObservationRegistry getObservationRegistry() {
495492
return this.observationRegistry;
@@ -574,51 +571,52 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<St
574571
this.observationRegistry = observationRegistry;
575572
this.customObservationConvention = customObservationConvention;
576573

577-
// @formatter:off
578-
this.aroundAdvisorChain = DefaultAroundAdvisorChain.builder(observationRegistry)
579-
// At the stack bottom add the non-streaming and streaming model call advisors.
580-
// They play the role of the last advisor in the around advisor chain.
581-
.push(new CallAroundAdvisor() {
582-
583-
@Override
584-
public String getName() {
585-
return CallAroundAdvisor.class.getSimpleName();
586-
}
587-
588-
@Override
589-
public int getOrder() {
590-
return Ordered.LOWEST_PRECEDENCE;
591-
}
592-
593-
@Override
594-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
595-
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext()));
596-
}
597-
})
598-
.push(new StreamAroundAdvisor() {
599-
600-
@Override
601-
public String getName() {
602-
return StreamAroundAdvisor.class.getSimpleName();
603-
}
604-
605-
@Override
606-
public int getOrder() {
607-
return Ordered.LOWEST_PRECEDENCE;
608-
}
609-
610-
@Override
611-
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
612-
// TODO: use aggregate stream advisors and apply over the original
613-
// stream
614-
return chatModel.stream(advisedRequest.toPrompt())
615-
.map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
616-
.publishOn(Schedulers.boundedElastic());// TODO add option to disable.
617-
}
618-
})
619-
.pushAll(this.advisors)
620-
.build();
621-
// @formatter:on
574+
// @formatter:off
575+
// At the stack bottom add the non-streaming and streaming model call advisors.
576+
// They play the role of the last advisor in the around advisor chain.
577+
this.advisors.add(new CallAroundAdvisor() {
578+
579+
@Override
580+
public String getName() {
581+
return CallAroundAdvisor.class.getSimpleName();
582+
}
583+
584+
@Override
585+
public int getOrder() {
586+
return Ordered.LOWEST_PRECEDENCE;
587+
}
588+
589+
@Override
590+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
591+
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext()));
592+
}
593+
});
594+
595+
this.advisors.add(new StreamAroundAdvisor() {
596+
597+
@Override
598+
public String getName() {
599+
return StreamAroundAdvisor.class.getSimpleName();
600+
}
601+
602+
@Override
603+
public int getOrder() {
604+
return Ordered.LOWEST_PRECEDENCE;
605+
}
606+
607+
@Override
608+
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
609+
// TODO: use aggregate stream advisors and apply over the original
610+
// stream
611+
return chatModel.stream(advisedRequest.toPrompt())
612+
.map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
613+
.publishOn(Schedulers.boundedElastic());// TODO add option to disable.
614+
}
615+
});
616+
// @formatter:on
617+
618+
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry)
619+
.pushAll(this.advisors);
622620
}
623621

624622
/**
@@ -648,18 +646,21 @@ public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer)
648646
consumer.accept(as);
649647
this.advisorParams.putAll(as.getParams());
650648
this.advisors.addAll(as.getAdvisors());
649+
this.aroundAdvisorChainBuilder.pushAll(as.getAdvisors());
651650
return this;
652651
}
653652

654653
public ChatClientRequestSpec advisors(Advisor... advisors) {
655654
Assert.notNull(advisors, "the advisors must be non-null");
656655
this.advisors.addAll(Arrays.asList(advisors));
656+
this.aroundAdvisorChainBuilder.pushAll(Arrays.asList(advisors));
657657
return this;
658658
}
659659

660660
public ChatClientRequestSpec advisors(List<Advisor> advisors) {
661661
Assert.notNull(advisors, "the advisors must be non-null");
662662
this.advisors.addAll(advisors);
663+
this.aroundAdvisorChainBuilder.pushAll(advisors);
663664
return this;
664665
}
665666

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java

Lines changed: 54 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -58,57 +58,16 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
5858

5959
private final ObservationRegistry observationRegistry;
6060

61-
DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, List<Advisor> advisors) {
62-
Assert.notNull(advisors, "the advisors must be non-null");
63-
this.observationRegistry = observationRegistry;
64-
this.callAroundAdvisors = new ConcurrentLinkedDeque<>();
65-
this.streamAroundAdvisors = new ConcurrentLinkedDeque<>();
66-
this.pushAll(advisors);
67-
}
68-
69-
void pushAll(List<? extends Advisor> advisors) {
70-
Assert.notNull(advisors, "the advisors must be non-null");
71-
if (!CollectionUtils.isEmpty(advisors)) {
72-
List<CallAroundAdvisor> callAroundAdvisors = advisors.stream()
73-
.filter(a -> a instanceof CallAroundAdvisor)
74-
.map(a -> (CallAroundAdvisor) a)
75-
.toList();
76-
77-
if (!CollectionUtils.isEmpty(callAroundAdvisors)) {
78-
callAroundAdvisors.forEach(this.callAroundAdvisors::push);
79-
}
61+
DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque<CallAroundAdvisor> callAroundAdvisors,
62+
Deque<StreamAroundAdvisor> streamAroundAdvisors) {
8063

81-
List<StreamAroundAdvisor> streamAroundAdvisors = advisors.stream()
82-
.filter(a -> a instanceof StreamAroundAdvisor)
83-
.map(a -> (StreamAroundAdvisor) a)
84-
.toList();
85-
86-
if (!CollectionUtils.isEmpty(streamAroundAdvisors)) {
87-
streamAroundAdvisors.forEach(this.streamAroundAdvisors::push);
88-
}
89-
90-
this.reOrder();
91-
}
92-
}
64+
Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
65+
Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null");
66+
Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null");
9367

94-
/**
95-
* (Re)orders the advisors in priority order based on their Ordered attribute.
96-
*
97-
* Note: this can be thread unsafe if the advisors are dynamically modified in the
98-
* prompt. To avoid this make sure to set advisors only in the ChatClient default
99-
* (e.g.builder) section.
100-
*/
101-
private void reOrder() {
102-
//
103-
ArrayList<CallAroundAdvisor> callAdvisors = new ArrayList<>(this.callAroundAdvisors);
104-
OrderComparator.sort(callAdvisors);
105-
this.callAroundAdvisors.clear();
106-
callAdvisors.forEach(this.callAroundAdvisors::addLast);
107-
108-
ArrayList<StreamAroundAdvisor> streamAdvisors = new ArrayList<>(this.streamAroundAdvisors);
109-
OrderComparator.sort(streamAdvisors);
110-
this.streamAroundAdvisors.clear();
111-
streamAdvisors.forEach(this.streamAroundAdvisors::addLast);
68+
this.observationRegistry = observationRegistry;
69+
this.callAroundAdvisors = callAroundAdvisors;
70+
this.streamAroundAdvisors = streamAroundAdvisors;
11271
}
11372

11473
@Override
@@ -171,28 +130,65 @@ public static class Builder {
171130

172131
private final ObservationRegistry observationRegistry;
173132

174-
// TODO(dj): this has all advisors actually; the build step filters the around
175-
// advisors
176-
private final List<Advisor> aroundAdvisors = new ArrayList<>();
133+
private final Deque<CallAroundAdvisor> callAroundAdvisors;
134+
135+
private final Deque<StreamAroundAdvisor> streamAroundAdvisors;
177136

178137
public Builder(ObservationRegistry observationRegistry) {
179138
this.observationRegistry = observationRegistry;
139+
this.callAroundAdvisors = new ConcurrentLinkedDeque<>();
140+
this.streamAroundAdvisors = new ConcurrentLinkedDeque<>();
180141
}
181142

182143
public Builder push(Advisor aroundAdvisor) {
183144
Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null");
184-
this.aroundAdvisors.add(aroundAdvisor);
185-
return this;
145+
return this.pushAll(List.of(aroundAdvisor));
186146
}
187147

188-
public Builder pushAll(List<Advisor> aroundAdvisors) {
189-
Assert.notNull(aroundAdvisors, "the aroundAdvisors must be non-null");
190-
this.aroundAdvisors.addAll(aroundAdvisors);
148+
public Builder pushAll(List<? extends Advisor> advisors) {
149+
Assert.notNull(advisors, "the advisors must be non-null");
150+
if (!CollectionUtils.isEmpty(advisors)) {
151+
List<CallAroundAdvisor> callAroundAdvisors = advisors.stream()
152+
.filter(a -> a instanceof CallAroundAdvisor)
153+
.map(a -> (CallAroundAdvisor) a)
154+
.toList();
155+
156+
if (!CollectionUtils.isEmpty(callAroundAdvisors)) {
157+
callAroundAdvisors.forEach(this.callAroundAdvisors::push);
158+
}
159+
160+
List<StreamAroundAdvisor> streamAroundAdvisors = advisors.stream()
161+
.filter(a -> a instanceof StreamAroundAdvisor)
162+
.map(a -> (StreamAroundAdvisor) a)
163+
.toList();
164+
165+
if (!CollectionUtils.isEmpty(streamAroundAdvisors)) {
166+
streamAroundAdvisors.forEach(this.streamAroundAdvisors::push);
167+
}
168+
169+
this.reOrder();
170+
}
191171
return this;
192172
}
193173

174+
/**
175+
* (Re)orders the advisors in priority order based on their Ordered attribute.
176+
*/
177+
private void reOrder() {
178+
ArrayList<CallAroundAdvisor> callAdvisors = new ArrayList<>(this.callAroundAdvisors);
179+
OrderComparator.sort(callAdvisors);
180+
this.callAroundAdvisors.clear();
181+
callAdvisors.forEach(this.callAroundAdvisors::addLast);
182+
183+
ArrayList<StreamAroundAdvisor> streamAdvisors = new ArrayList<>(this.streamAroundAdvisors);
184+
OrderComparator.sort(streamAdvisors);
185+
this.streamAroundAdvisors.clear();
186+
streamAdvisors.forEach(this.streamAroundAdvisors::addLast);
187+
}
188+
194189
public DefaultAroundAdvisorChain build() {
195-
return new DefaultAroundAdvisorChain(this.observationRegistry, this.aroundAdvisors);
190+
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAroundAdvisors,
191+
this.streamAroundAdvisors);
196192
}
197193

198194
}

0 commit comments

Comments
 (0)