Skip to content

Commit 5205f3a

Browse files
committed
Refactored advisors even further
1 parent 43aeb1e commit 5205f3a

File tree

22 files changed

+117
-340
lines changed

22 files changed

+117
-340
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import org.springframework.ai.chat.client.ChatClient;
3232
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3333
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
34-
import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain;
34+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
3535
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
3636
import org.springframework.ai.converter.BeanOutputConverter;
3737
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -79,7 +79,7 @@ public int getOrder() {
7979
}
8080

8181
@Override
82-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
82+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
8383

8484
advisedRequest = this.before(advisedRequest);
8585

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import org.springframework.ai.chat.client.ChatClient;
3232
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3333
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
34-
import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain;
34+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
3535
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
3636
import org.springframework.ai.model.function.FunctionCallbackContext;
3737
import org.springframework.ai.model.function.FunctionCallbackWrapper.Builder.SchemaType;
@@ -80,7 +80,7 @@ public int getOrder() {
8080
}
8181

8282
@Override
83-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
83+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
8484
var response = chain.nextAroundCall(before(advisedRequest));
8585
observeAfter(response);
8686
return response;

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
3333
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
3434
import org.springframework.ai.chat.client.advisor.api.Advisor;
35-
import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain;
35+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
3636
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
3737
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
38+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
3839
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
3940
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
4041
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
@@ -419,7 +420,17 @@ private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequ
419420
// @formatter:off
420421
// Apply the around advisor chain that terminates with the, last,
421422
// model call advisor.
422-
return inputRequest.aroundAdvisorChain.nextAroundStream(initialAdvisedRequest)
423+
// StreamAggregationAdvisor aggregationAdvisor =
424+
// advisedResponse -> couchbaseClient.streamToBucket(advisedResponse).then();
425+
//
426+
// List<StreamAggregationAdvisor> streamAggregationAdvisors = List.of(aggregationAdvisor);
427+
Flux<AdvisedResponse> stream = inputRequest.aroundAdvisorChain.nextAroundStream(initialAdvisedRequest);
428+
429+
// if (aggregationAdvisor != null) {
430+
// stream = new MessageAggregator().aggregateAdvisedResponse(stream, aggregationAdvisor);
431+
// }
432+
433+
return stream
423434
.map(AdvisedResponse::response)
424435
.doOnError(observation::error)
425436
.doFinally(s -> observation.stop())
@@ -476,7 +487,7 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
476487

477488
private final DefaultAroundAdvisorChain aroundAdvisorChain;
478489

479-
public AroundAdvisorChain getAroundAdvisorChain() {
490+
public CallAroundAdvisorChain getAroundAdvisorChain() {
480491
return this.aroundAdvisorChain;
481492
}
482493

@@ -580,7 +591,7 @@ public int getOrder() {
580591
}
581592

582593
@Override
583-
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
594+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
584595
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext()));
585596
}
586597
})
@@ -597,7 +608,9 @@ public int getOrder() {
597608
}
598609

599610
@Override
600-
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
611+
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
612+
// TODO: use aggregate stream advisors and apply over the original
613+
// stream
601614
return chatModel.stream(advisedRequest.toPrompt())
602615
.map( chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
603616
.publishOn(Schedulers.boundedElastic());// TODO add option to disable.
@@ -635,21 +648,18 @@ public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer)
635648
consumer.accept(as);
636649
this.advisorParams.putAll(as.getParams());
637650
this.advisors.addAll(as.getAdvisors());
638-
this.aroundAdvisorChain.pushAll(as.getAdvisors());
639651
return this;
640652
}
641653

642654
public ChatClientRequestSpec advisors(Advisor... advisors) {
643655
Assert.notNull(advisors, "the advisors must be non-null");
644656
this.advisors.addAll(Arrays.asList(advisors));
645-
this.aroundAdvisorChain.pushAll(Arrays.asList(advisors));
646657
return this;
647658
}
648659

649660
public ChatClientRequestSpec advisors(List<Advisor> advisors) {
650661
Assert.notNull(advisors, "the advisors must be non-null");
651662
this.advisors.addAll(advisors);
652-
this.aroundAdvisorChain.pushAll(advisors);
653663
return this;
654664
}
655665

spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
import java.util.HashMap;
2121
import java.util.Map;
2222
import java.util.concurrent.ConcurrentHashMap;
23-
import java.util.concurrent.atomic.AtomicReference;
2423

2524
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
2625
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
27-
import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain;
26+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
2827
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
2928
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
29+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
3030
import org.springframework.ai.chat.model.ChatModel;
3131
import org.springframework.ai.chat.model.ChatResponse;
3232
import org.springframework.ai.chat.prompt.Prompt;
@@ -38,8 +38,8 @@
3838
* {@link ChatModel#stream(Prompt)} methods calls. The {@link ChatClient} maintains a
3939
* chain of advisors with shared advise context.
4040
*
41-
* @deprecated since 1.0.0 please use {@link CallAroundAdvisor},
42-
* {@link StreamAroundAdvisor}, {@link ObservingAfterAdvisor} instead.
41+
* @deprecated since 1.0.0 please use {@link CallAroundAdvisor} or
42+
* {@link StreamAroundAdvisor} instead.
4343
* @author Christian Tzolov
4444
* @since 1.0.0
4545
*/
@@ -64,7 +64,7 @@ default Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<S
6464
}
6565

6666
@Override
67-
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
67+
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
6868
var context = new HashMap<>(advisedRequest.adviseContext());
6969
var requestPrim = adviseRequest(advisedRequest, context);
7070
advisedRequest = AdvisedRequest.from(requestPrim)
@@ -78,7 +78,7 @@ default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, AroundAdvisorC
7878
return new AdvisedResponse(chatResponse, Collections.unmodifiableMap(context));
7979
}
8080

81-
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, AroundAdvisorChain chain) {
81+
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
8282

8383
ConcurrentHashMap<String, Object> context = new ConcurrentHashMap<>(advisedRequest.adviseContext());
8484

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor2.java

Lines changed: 0 additions & 92 deletions
This file was deleted.

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
99
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
1010
import org.springframework.ai.chat.client.advisor.api.Advisor;
11-
import org.springframework.ai.chat.client.advisor.api.AroundAdvisorChain;
11+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
1212
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
1313
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
14+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
1415
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;
1516
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
1617
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;
@@ -23,7 +24,7 @@
2324
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
2425
import reactor.core.publisher.Flux;
2526

26-
public class DefaultAroundAdvisorChain implements AroundAdvisorChain {
27+
public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain {
2728

2829
public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
2930

@@ -33,35 +34,25 @@ public class DefaultAroundAdvisorChain implements AroundAdvisorChain {
3334

3435
private final ObservationRegistry observationRegistry;
3536

36-
public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry) {
37-
this(observationRegistry, new ArrayDeque<CallAroundAdvisor>(), new ArrayDeque<StreamAroundAdvisor>());
38-
}
39-
40-
public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry,
41-
Deque<CallAroundAdvisor> callAroundAdvisors, Deque<StreamAroundAdvisor> streamAroundAdvisors) {
42-
Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null");
37+
DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, List<Advisor> advisors) {
38+
Assert.notNull(advisors, "the callAroundAdvisors must be non-null");
4339
this.observationRegistry = observationRegistry;
44-
this.callAroundAdvisors = callAroundAdvisors;
45-
this.streamAroundAdvisors = streamAroundAdvisors;
46-
}
47-
48-
public DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, List<Advisor> advisors) {
49-
this(observationRegistry);
40+
this.callAroundAdvisors = new ArrayDeque<>();
41+
this.streamAroundAdvisors = new ArrayDeque<>();
5042
Assert.notNull(advisors, "the advisors must be non-null");
5143
this.pushAll(advisors);
5244
}
5345

54-
public void pushAll(List<? extends Advisor> advisors) {
46+
void pushAll(List<? extends Advisor> advisors) {
5547
Assert.notNull(advisors, "the advisors must be non-null");
5648
if (!CollectionUtils.isEmpty(advisors)) {
57-
5849
List<CallAroundAdvisor> callAroundAdvisors = advisors.stream()
5950
.filter(a -> a instanceof CallAroundAdvisor)
6051
.map(a -> (CallAroundAdvisor) a)
6152
.toList();
6253

6354
if (!CollectionUtils.isEmpty(callAroundAdvisors)) {
64-
callAroundAdvisors.stream().forEach(this.callAroundAdvisors::push);
55+
callAroundAdvisors.forEach(this.callAroundAdvisors::push);
6556
}
6657

6758
List<StreamAroundAdvisor> streamAroundAdvisors = advisors.stream()
@@ -70,7 +61,7 @@ public void pushAll(List<? extends Advisor> advisors) {
7061
.toList();
7162

7263
if (!CollectionUtils.isEmpty(streamAroundAdvisors)) {
73-
streamAroundAdvisors.stream().forEach(this.streamAroundAdvisors::push);
64+
streamAroundAdvisors.forEach(this.streamAroundAdvisors::push);
7465
}
7566

7667
this.reOrder();
@@ -79,16 +70,15 @@ public void pushAll(List<? extends Advisor> advisors) {
7970

8071
public void reOrder() {
8172
// Order the advisors in priority order based on their Ordered attribute.
82-
8373
ArrayList<CallAroundAdvisor> temp = new ArrayList<>(this.callAroundAdvisors);
8474
OrderComparator.sort(temp);
8575
this.callAroundAdvisors.clear();
86-
temp.stream().forEach(this.callAroundAdvisors::addLast);
76+
temp.forEach(this.callAroundAdvisors::addLast);
8777

8878
ArrayList<StreamAroundAdvisor> temp2 = new ArrayList<>(this.streamAroundAdvisors);
8979
OrderComparator.sort(temp2);
9080
this.streamAroundAdvisors.clear();
91-
temp2.stream().forEach(this.streamAroundAdvisors::addLast);
81+
temp2.forEach(this.streamAroundAdvisors::addLast);
9282
}
9383

9484
@Override
@@ -114,9 +104,7 @@ public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
114104

115105
@Override
116106
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
117-
118107
return Flux.deferContextual(contextView -> {
119-
120108
if (this.streamAroundAdvisors.isEmpty()) {
121109
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
122110
}
@@ -137,9 +125,10 @@ public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
137125

138126
// @formatter:off
139127
return Flux.defer(() -> advisor.aroundStream(advisedRequest, this))
140-
.doOnError(observation::error)
141-
.doFinally(s -> { observation.stop();
142-
}).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
128+
.doOnError(observation::error)
129+
.doFinally(s -> {
130+
observation.stop();
131+
}).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
143132
// @formatter:on
144133
});
145134
}
@@ -152,10 +141,12 @@ public static class Builder {
152141

153142
private final DefaultAroundAdvisorChain aroundAdvisorChain;
154143

144+
// TODO(dj): this has all advisors actually; the build step filters the around
145+
// advisors
155146
private final List<Advisor> aroundAdvisors = new ArrayList<>();
156147

157148
public Builder(ObservationRegistry observationRegistry) {
158-
this.aroundAdvisorChain = new DefaultAroundAdvisorChain(observationRegistry);
149+
this.aroundAdvisorChain = new DefaultAroundAdvisorChain(observationRegistry, aroundAdvisors);
159150
}
160151

161152
public Builder push(Advisor aroundAdvisor) {

0 commit comments

Comments
 (0)