Skip to content

Commit 16bb312

Browse files
committed
Introduce first-class chat memory support
- ChatMemory will become a generic interface to implement different memory management strategies. It’s been moved from the “”spring-ai-client-chat” package to “spring-ai-model” package while retaining the same package, so it’s transparent to users. - A MessageWindowChatMemory has been introduced to provide support for a chat memory that keeps at most N messages in the memory. - A MessageWindowProcessingPolicy API has been introduced to customise the processing policy for the message window. A default implementation is provided out-of-the-box. - A ChatMemoryRepository interface has been introduced to support different storage strategies for the chat memory. It’s meant to be used as part of a ChatMemory implementation. This is different than before, where the storage-specific implementation was directly tied to the ChatMemory. This design is familiar to Spring users since it’s used already in the ecosystem. The goal was to use a programming model similar to Spring Session and Spring Data. - The ChatClient now supports memory as a first-class citizen, superseding the need for an Advisor to manage the chat memory. It also simplifies providing a conversationId. This feature lays the foundation for including the intermediate messages in tool calling in the memory as well. - All the changes introduced in this PR are backword-compatible. Signed-off-by: Thomas Vitale <[email protected]>
1 parent c0bc623 commit 16bb312

24 files changed

+1193
-58
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
3030
import org.junit.jupiter.params.provider.ValueSource;
3131
import org.slf4j.Logger;
3232
import org.slf4j.LoggerFactory;
33+
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
3334
import reactor.core.publisher.Flux;
3435

3536
import org.springframework.ai.chat.client.ChatClient;
@@ -378,6 +379,69 @@ void multiModalityAudioResponse() {
378379
logger.info("Response: " + response);
379380
}
380381

382+
@Test
383+
void chatMemoryWithDefaults() {
384+
ChatClient chatClient = ChatClient.builder(this.chatModel)
385+
.defaultMemory(MessageWindowChatMemory.builder().build())
386+
.build();
387+
388+
String conversationId = "007";
389+
390+
ChatResponse response1 = chatClient.prompt("My name is Bond. James Bond.")
391+
.conversationId(conversationId)
392+
.call()
393+
.chatResponse();
394+
395+
assertThat(response1).isNotNull();
396+
397+
ChatResponse response2 = chatClient.prompt("What is my name?")
398+
.conversationId(conversationId)
399+
.call()
400+
.chatResponse();
401+
402+
assertThat(response2).isNotNull();
403+
assertThat(response2.getResults()).hasSize(1);
404+
assertThat(response2.getResults().get(0).getOutput().getText()).contains("James Bond");
405+
}
406+
407+
@Test
408+
void chatMemoryWithMessageWindowSize() {
409+
ChatClient chatClient = ChatClient.builder(this.chatModel)
410+
.defaultMemory(MessageWindowChatMemory.builder().maxMessages(3).build())
411+
.build();
412+
413+
String conversationId = "007";
414+
415+
ChatResponse response1 = chatClient.prompt("The cat is on the table")
416+
.conversationId(conversationId)
417+
.call()
418+
.chatResponse();
419+
420+
assertThat(response1).isNotNull();
421+
422+
ChatResponse response2 = chatClient.prompt("My name is Bond. James Bond.")
423+
.conversationId(conversationId)
424+
.call()
425+
.chatResponse();
426+
427+
assertThat(response2).isNotNull();
428+
429+
ChatResponse response3 = chatClient.prompt("What is my name?")
430+
.conversationId(conversationId)
431+
.call()
432+
.chatResponse();
433+
434+
assertThat(response3).isNotNull();
435+
436+
ChatResponse response4 = chatClient.prompt("Where is the cat?")
437+
.conversationId(conversationId)
438+
.call()
439+
.chatResponse();
440+
441+
assertThat(response4).isNotNull();
442+
assertThat(response2.getResults().get(0).getOutput().getText()).doesNotContainIgnoringCase("table");
443+
}
444+
381445
record ActorsFilms(String actor, List<String> movies) {
382446

383447
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import org.springframework.ai.chat.client.advisor.api.Advisor;
2929
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
30+
import org.springframework.ai.chat.memory.ChatMemory;
3031
import org.springframework.ai.chat.messages.Message;
3132
import org.springframework.ai.chat.model.ChatModel;
3233
import org.springframework.ai.chat.model.ChatResponse;
@@ -247,6 +248,10 @@ interface ChatClientRequestSpec {
247248

248249
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
249250

251+
ChatClientRequestSpec memory(ChatMemory chatMemory);
252+
253+
ChatClientRequestSpec conversationId(String conversationId);
254+
250255
CallResponseSpec call();
251256

252257
StreamResponseSpec stream();
@@ -294,6 +299,8 @@ interface Builder {
294299

295300
Builder defaultToolContext(Map<String, Object> toolContext);
296301

302+
Builder defaultMemory(ChatMemory chatMemory);
303+
297304
Builder clone();
298305

299306
ChatClient build();

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public static class Builder {
5151
private Builder() {
5252
}
5353

54-
public Builder chatResponse(ChatResponse chatResponse) {
54+
public Builder chatResponse(@Nullable ChatResponse chatResponse) {
5555
this.chatResponse = chatResponse;
5656
return this;
5757
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
4545
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
4646
import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
47+
import org.springframework.ai.chat.memory.ChatMemory;
4748
import org.springframework.ai.chat.messages.AbstractMessage;
4849
import org.springframework.ai.chat.messages.Message;
4950
import org.springframework.ai.chat.messages.MessageType;
@@ -131,7 +132,7 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu
131132
public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
132133
ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) {
133134

134-
return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(),
135+
return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), null, null, advisedRequest.userText(),
135136
advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(),
136137
advisedRequest.toolCallbacks(), advisedRequest.messages(), advisedRequest.toolNames(),
137138
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
@@ -660,10 +661,13 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
660661

661662
private final Map<String, Object> advisorParams = new HashMap<>();
662663

663-
private final DefaultAroundAdvisorChain.Builder aroundAdvisorChainBuilder;
664-
665664
private final Map<String, Object> toolContext = new HashMap<>();
666665

666+
private String conversationId;
667+
668+
@Nullable
669+
private ChatMemory chatMemory;
670+
667671
@Nullable
668672
private String userText;
669673

@@ -675,16 +679,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
675679

676680
/* copy constructor */
677681
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
678-
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
679-
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
680-
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext);
682+
this(ccr.chatModel, ccr.chatMemory, ccr.conversationId, ccr.userText, ccr.userParams, ccr.systemText,
683+
ccr.systemParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
684+
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
685+
ccr.toolContext);
681686
}
682687

683-
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
684-
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
685-
List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
686-
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
687-
ObservationRegistry observationRegistry,
688+
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable ChatMemory chatMemory,
689+
@Nullable String conversationId, @Nullable String userText, Map<String, Object> userParams,
690+
@Nullable String systemText, Map<String, Object> systemParams, List<ToolCallback> toolCallbacks,
691+
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
692+
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
688693
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext) {
689694

690695
Assert.notNull(chatModel, "chatModel cannot be null");
@@ -703,6 +708,10 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
703708
this.chatOptions = chatOptions != null ? chatOptions.copy()
704709
: (chatModel.getDefaultOptions() != null) ? chatModel.getDefaultOptions().copy() : null;
705710

711+
this.chatMemory = chatMemory;
712+
this.conversationId = StringUtils.hasText(conversationId) ? conversationId
713+
: ChatMemory.DEFAULT_CONVERSATION_ID;
714+
706715
this.userText = userText;
707716
this.userParams.putAll(userParams);
708717
this.systemText = systemText;
@@ -723,9 +732,6 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
723732
// They play the role of the last advisors in the advisor chain.
724733
this.advisors.add(new ChatModelCallAdvisor(chatModel));
725734
this.advisors.add(new ChatModelStreamAdvisor(chatModel));
726-
727-
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry)
728-
.pushAll(this.advisors);
729735
}
730736

731737
private ObservationRegistry getObservationRegistry() {
@@ -787,6 +793,15 @@ public Map<String, Object> getToolContext() {
787793
return this.toolContext;
788794
}
789795

796+
public String getConversationId() {
797+
return this.conversationId;
798+
}
799+
800+
@Nullable
801+
public ChatMemory getChatMemory() {
802+
return this.chatMemory;
803+
}
804+
790805
/**
791806
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
792807
* settings are replicated from this {@code ChatClientRequest}.
@@ -822,23 +837,20 @@ public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer)
822837
consumer.accept(advisorSpec);
823838
this.advisorParams.putAll(advisorSpec.getParams());
824839
this.advisors.addAll(advisorSpec.getAdvisors());
825-
this.aroundAdvisorChainBuilder.pushAll(advisorSpec.getAdvisors());
826840
return this;
827841
}
828842

829843
public ChatClientRequestSpec advisors(Advisor... advisors) {
830844
Assert.notNull(advisors, "advisors cannot be null");
831845
Assert.noNullElements(advisors, "advisors cannot contain null elements");
832846
this.advisors.addAll(Arrays.asList(advisors));
833-
this.aroundAdvisorChainBuilder.pushAll(Arrays.asList(advisors));
834847
return this;
835848
}
836849

837850
public ChatClientRequestSpec advisors(List<Advisor> advisors) {
838851
Assert.notNull(advisors, "advisors cannot be null");
839852
Assert.noNullElements(advisors, "advisors cannot contain null elements");
840853
this.advisors.addAll(advisors);
841-
this.aroundAdvisorChainBuilder.pushAll(advisors);
842854
return this;
843855
}
844856

@@ -982,18 +994,40 @@ public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
982994
return this;
983995
}
984996

997+
@Override
998+
public ChatClientRequestSpec memory(ChatMemory chatMemory) {
999+
Assert.notNull(chatMemory, "chatMemory cannot be null");
1000+
this.chatMemory = chatMemory;
1001+
return this;
1002+
}
1003+
1004+
@Override
1005+
public ChatClientRequestSpec conversationId(String conversationId) {
1006+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
1007+
this.conversationId = conversationId;
1008+
return this;
1009+
}
1010+
9851011
public CallResponseSpec call() {
986-
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
1012+
BaseAdvisorChain advisorChain = buildAdvisorChain();
9871013
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
9881014
observationRegistry, observationConvention);
9891015
}
9901016

9911017
public StreamResponseSpec stream() {
992-
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
1018+
BaseAdvisorChain advisorChain = buildAdvisorChain();
9931019
return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
9941020
observationRegistry, observationConvention);
9951021
}
9961022

1023+
private BaseAdvisorChain buildAdvisorChain() {
1024+
return DefaultAroundAdvisorChain.builder(this.observationRegistry)
1025+
.conversationId(this.conversationId)
1026+
.chatMemory(this.chatMemory)
1027+
.pushAll(this.advisors)
1028+
.build();
1029+
}
1030+
9971031
}
9981032

9991033
// Prompt

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
3131
import org.springframework.ai.chat.client.advisor.api.Advisor;
3232
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
33+
import org.springframework.ai.chat.memory.ChatMemory;
3334
import org.springframework.ai.chat.messages.Message;
3435
import org.springframework.ai.chat.model.ChatModel;
3536
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -64,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa
6465
@Nullable ChatClientObservationConvention customObservationConvention) {
6566
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
6667
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
67-
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(),
68-
List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
68+
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, null, null, Map.of(), null, Map.of(),
69+
List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry,
6970
customObservationConvention, Map.of());
7071
}
7172

@@ -190,6 +191,13 @@ public Builder defaultToolContext(Map<String, Object> toolContext) {
190191
return this;
191192
}
192193

194+
@Override
195+
public Builder defaultMemory(ChatMemory chatMemory) {
196+
Assert.notNull(chatMemory, "chatMemory cannot be null");
197+
this.defaultRequest.memory(chatMemory);
198+
return this;
199+
}
200+
193201
void addMessages(List<Message> messages) {
194202
this.defaultRequest.messages(messages);
195203
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,16 @@
2020
import org.springframework.ai.chat.client.ChatClientResponse;
2121
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
2222
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
23+
import org.springframework.ai.chat.memory.ChatMemory;
24+
import org.springframework.ai.chat.messages.Message;
2325
import org.springframework.ai.chat.model.ChatModel;
2426
import org.springframework.ai.chat.model.ChatResponse;
27+
import org.springframework.ai.chat.model.Generation;
28+
import org.springframework.ai.chat.prompt.Prompt;
2529
import org.springframework.core.Ordered;
2630
import org.springframework.util.Assert;
2731

32+
import java.util.List;
2833
import java.util.Map;
2934

3035
/**
@@ -45,7 +50,29 @@ public ChatModelCallAdvisor(ChatModel chatModel) {
4550
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
4651
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
4752

48-
ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt());
53+
ChatMemory chatMemory = chain.getChatMemory();
54+
55+
ChatResponse chatResponse;
56+
if (chatMemory == null) {
57+
chatResponse = chatModel.call(chatClientRequest.prompt());
58+
}
59+
else {
60+
String conversationId = chain.getConversationId();
61+
chatMemory.add(conversationId, chatClientRequest.prompt().getInstructions());
62+
Prompt prompt = chatClientRequest.prompt().mutate().messages(chatMemory.get(conversationId)).build();
63+
chatResponse = chatModel.call(prompt);
64+
if (chatResponse != null) {
65+
List<Generation> generations = chatResponse.getResults();
66+
if (generations != null) {
67+
List<Message> assistantMessages = generations.stream()
68+
.map(generation -> (Message) generation.getOutput())
69+
.toList();
70+
chatMemory.add(conversationId, assistantMessages);
71+
}
72+
}
73+
74+
}
75+
4976
return ChatClientResponse.builder()
5077
.chatResponse(chatResponse)
5178
.context(Map.copyOf(chatClientRequest.context()))

0 commit comments

Comments
 (0)