Skip to content

Commit 5029d00

Browse files
committed
🩹 use memoized response on chat client [#2097]
1 parent f5761de commit 5029d00

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 34 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -16,27 +16,9 @@
1616

1717
package org.springframework.ai.chat.client;
1818

19-
import java.io.IOException;
20-
import java.net.URL;
21-
import java.nio.charset.Charset;
22-
import java.util.ArrayList;
23-
import java.util.Arrays;
24-
import java.util.Collection;
25-
import java.util.Collections;
26-
import java.util.HashMap;
27-
import java.util.List;
28-
import java.util.Map;
29-
import java.util.Optional;
30-
import java.util.concurrent.ConcurrentHashMap;
31-
import java.util.function.Consumer;
32-
3319
import io.micrometer.observation.Observation;
3420
import io.micrometer.observation.ObservationRegistry;
3521
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
36-
import org.springframework.ai.tool.ToolCallbacks;
37-
import reactor.core.publisher.Flux;
38-
import reactor.core.scheduler.Schedulers;
39-
4022
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
4123
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
4224
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
@@ -65,6 +47,7 @@
6547
import org.springframework.ai.model.Media;
6648
import org.springframework.ai.model.function.FunctionCallback;
6749
import org.springframework.ai.model.function.FunctionCallbackWrapper;
50+
import org.springframework.ai.tool.ToolCallbacks;
6851
import org.springframework.core.Ordered;
6952
import org.springframework.core.ParameterizedTypeReference;
7053
import org.springframework.core.io.Resource;
@@ -73,6 +56,22 @@
7356
import org.springframework.util.CollectionUtils;
7457
import org.springframework.util.MimeType;
7558
import org.springframework.util.StringUtils;
59+
import reactor.core.publisher.Flux;
60+
import reactor.core.scheduler.Schedulers;
61+
62+
import java.io.IOException;
63+
import java.net.URL;
64+
import java.nio.charset.Charset;
65+
import java.util.ArrayList;
66+
import java.util.Arrays;
67+
import java.util.Collection;
68+
import java.util.Collections;
69+
import java.util.HashMap;
70+
import java.util.List;
71+
import java.util.Map;
72+
import java.util.Optional;
73+
import java.util.concurrent.ConcurrentHashMap;
74+
import java.util.function.Consumer;
7675

7776
/**
7877
* The default implementation of {@link ChatClient} as created by the
@@ -393,6 +392,8 @@ public static class DefaultCallResponseSpec implements CallResponseSpec {
393392

394393
private final DefaultChatClientRequestSpec request;
395394

395+
private final ThreadLocal<Optional<ChatResponse>> memoizedResponse = ThreadLocal.withInitial(Optional::empty);
396+
396397
public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) {
397398
Assert.notNull(request, "request cannot be null");
398399
this.request = request;
@@ -506,13 +507,16 @@ private static String getContentFromChatResponse(@Nullable ChatResponse chatResp
506507
@Override
507508
@Nullable
508509
public ChatResponse chatResponse() {
509-
return doGetChatResponse();
510+
final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse);
511+
memoizedResponse.set(Optional.ofNullable(chatResponse));
512+
return chatResponse;
510513
}
511514

512515
@Override
513516
@Nullable
514517
public String content() {
515-
ChatResponse chatResponse = doGetChatResponse();
518+
final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse);
519+
memoizedResponse.set(Optional.ofNullable(chatResponse));
516520
return getContentFromChatResponse(chatResponse);
517521
}
518522

@@ -522,6 +526,8 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec {
522526

523527
private final DefaultChatClientRequestSpec request;
524528

529+
private final ThreadLocal<Optional<Flux<ChatResponse>>> memoizedFlux = ThreadLocal.withInitial(Optional::empty);
530+
525531
public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) {
526532
Assert.notNull(request, "request cannot be null");
527533
this.request = request;
@@ -559,12 +565,18 @@ private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequ
559565

560566
@Override
561567
public Flux<ChatResponse> chatResponse() {
562-
return doGetObservableFluxChatResponse(this.request);
568+
final var chatResponseFlux = memoizedFlux.get()
569+
.orElseGet(() -> doGetObservableFluxChatResponse(this.request));
570+
memoizedFlux.set(Optional.of(chatResponseFlux));
571+
return chatResponseFlux;
563572
}
564573

565574
@Override
566575
public Flux<String> content() {
567-
return doGetObservableFluxChatResponse(this.request).map(r -> {
576+
final var chatResponseFlux = memoizedFlux.get()
577+
.orElseGet(() -> doGetObservableFluxChatResponse(this.request));
578+
memoizedFlux.set(Optional.of(chatResponseFlux));
579+
return chatResponseFlux.map(r -> {
568580
if (r.getResult() == null || r.getResult().getOutput() == null
569581
|| r.getResult().getOutput().getText() == null) {
570582
return "";

0 commit comments

Comments
 (0)