Skip to content

Commit f4972e7

Browse files
committed
Introduce first-class chat memory support
- ChatMemory will become a generic interface to implement different memory management strategies. It’s been moved from the “”spring-ai-client-chat” package to “spring-ai-model” package while retaining the same package, so it’s transparent to users. - A MessageWindowChatMemory has been introduced to provide support for a chat memory that keeps at most N messages in the memory. - A MessageWindowProcessingPolicy API has been introduced to customise the processing policy for the message window. A default implementation is provided out-of-the-box. - A ChatMemoryRepository interface has been introduced to support different storage strategies for the chat memory. It’s meant to be used as part of a ChatMemory implementation. This is different than before, where the storage-specific implementation was directly tied to the ChatMemory. This design is familiar to Spring users since it’s used already in the ecosystem. The goal was to use a programming model similar to Spring Session and Spring Data. - The JdbcChatMemory has been supersed by JdbcChatMemoryRepository. - The ChatClient now supports memory as a first-class citizen, superseding the need for an Advisor to manage the chat memory. It also simplifies providing a conversationId. This feature lays the foundation for including the intermediate messages in tool calling in the memory as well. - All the changes introduced in this PR are backword-compatible. Signed-off-by: Thomas Vitale <[email protected]>
1 parent c0bc623 commit f4972e7

File tree

36 files changed

+1687
-67
lines changed

36 files changed

+1687
-67
lines changed

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
*
3939
* @author Jonathan Leijendekker
4040
* @since 1.0.0
41+
* @deprecated in favor of providing ChatClient directly a
42+
* {@link org.springframework.ai.chat.memory.MessageWindowChatMemory} with a
43+
* {@link JdbcChatMemoryRepository} instance.
4144
*/
45+
@Deprecated
4246
public class JdbcChatMemory implements ChatMemory {
4347

4448
private static final String QUERY_ADD = """

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import org.springframework.util.Assert;
2121

2222
/**
23-
* Configuration for {@link JdbcChatMemory}.
23+
* Configuration for {@link JdbcChatMemoryRepository}.
2424
*
2525
* @author Jonathan Leijendekker
2626
* @since 1.0.0
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.jdbc;
18+
19+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
20+
import org.springframework.ai.chat.messages.*;
21+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
22+
import org.springframework.jdbc.core.JdbcTemplate;
23+
import org.springframework.jdbc.core.RowMapper;
24+
import org.springframework.lang.Nullable;
25+
import org.springframework.util.Assert;
26+
27+
import java.sql.PreparedStatement;
28+
import java.sql.ResultSet;
29+
import java.sql.SQLException;
30+
import java.util.List;
31+
32+
/**
33+
* An implementation of {@link ChatMemoryRepository} for JDBC.
34+
*
35+
* @author Jonathan Leijendekker
36+
* @author Thomas Vitale
37+
* @since 1.0.0
38+
*/
39+
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
40+
41+
private static final String QUERY_ADD = """
42+
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
43+
44+
private static final String QUERY_GET = """
45+
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC""";
46+
47+
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
48+
49+
private final JdbcTemplate jdbcTemplate;
50+
51+
private JdbcChatMemoryRepository(JdbcChatMemoryConfig config) {
52+
Assert.notNull(config, "config cannot be null");
53+
this.jdbcTemplate = config.getJdbcTemplate();
54+
}
55+
56+
public static JdbcChatMemoryRepository create(JdbcChatMemoryConfig config) {
57+
return new JdbcChatMemoryRepository(config);
58+
}
59+
60+
@Override
61+
public List<Message> findById(String conversationId) {
62+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
63+
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId);
64+
}
65+
66+
@Override
67+
public void save(String conversationId, List<Message> messages) {
68+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
69+
Assert.notNull(messages, "messages cannot be null");
70+
Assert.noNullElements(messages, "messages cannot contain null elements");
71+
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
72+
}
73+
74+
@Override
75+
public void deleteById(String conversationId) {
76+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
77+
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
78+
}
79+
80+
private record AddBatchPreparedStatement(String conversationId,
81+
List<Message> messages) implements BatchPreparedStatementSetter {
82+
@Override
83+
public void setValues(PreparedStatement ps, int i) throws SQLException {
84+
var message = this.messages.get(i);
85+
86+
ps.setString(1, this.conversationId);
87+
ps.setString(2, message.getText());
88+
ps.setString(3, message.getMessageType().name());
89+
}
90+
91+
@Override
92+
public int getBatchSize() {
93+
return this.messages.size();
94+
}
95+
}
96+
97+
private static class MessageRowMapper implements RowMapper<Message> {
98+
99+
@Override
100+
@Nullable
101+
public Message mapRow(ResultSet rs, int i) throws SQLException {
102+
var content = rs.getString(1);
103+
var type = MessageType.valueOf(rs.getString(2));
104+
105+
return switch (type) {
106+
case USER -> new UserMessage(content);
107+
case ASSISTANT -> new AssistantMessage(content);
108+
case SYSTEM -> new SystemMessage(content);
109+
case TOOL -> null;
110+
};
111+
}
112+
113+
}
114+
115+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
*
2828
* @author Jonathan Leijendekker
2929
*/
30-
class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar {
30+
class JdbcChatMemoryRepositoryRuntimeHints implements RuntimeHintsRegistrar {
3131

3232
@Override
3333
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
@NonNullApi
18+
@NonNullFields
19+
package org.springframework.ai.chat.memory.jdbc;
20+
21+
import org.springframework.lang.NonNullApi;
22+
import org.springframework.lang.NonNullFields;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
org.springframework.aot.hint.RuntimeHintsRegistrar=\
2-
org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints
2+
org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRepositoryRuntimeHints
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.jdbc;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.params.ParameterizedTest;
21+
import org.junit.jupiter.params.provider.CsvSource;
22+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
23+
import org.springframework.ai.chat.messages.Message;
24+
import org.springframework.ai.chat.messages.MessageType;
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
import org.springframework.ai.chat.messages.SystemMessage;
27+
import org.springframework.ai.chat.messages.UserMessage;
28+
import org.springframework.boot.SpringBootConfiguration;
29+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
30+
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
31+
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
32+
import org.springframework.boot.context.properties.ConfigurationProperties;
33+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
34+
import org.springframework.context.annotation.Bean;
35+
import org.springframework.context.annotation.Primary;
36+
import org.springframework.jdbc.core.JdbcTemplate;
37+
import org.testcontainers.containers.PostgreSQLContainer;
38+
import org.testcontainers.junit.jupiter.Container;
39+
import org.testcontainers.junit.jupiter.Testcontainers;
40+
import org.testcontainers.utility.MountableFile;
41+
42+
import javax.sql.DataSource;
43+
import java.sql.Timestamp;
44+
import java.util.List;
45+
import java.util.UUID;
46+
47+
import static org.assertj.core.api.Assertions.assertThat;
48+
49+
/**
50+
* Integration tests for {@link JdbcChatMemoryRepository}.
51+
*
52+
* @author Jonathan Leijendekker
53+
* @author Thomas Vitale
54+
*/
55+
@Testcontainers
56+
class JdbcChatMemoryRepositoryIT {
57+
58+
@Container
59+
@SuppressWarnings("resource")
60+
static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>("postgres:17")
61+
.withDatabaseName("chat_memory_test")
62+
.withUsername("postgres")
63+
.withPassword("postgres")
64+
.withCopyFileToContainer(
65+
MountableFile.forClasspathResource("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"),
66+
"/docker-entrypoint-initdb.d/schema.sql");
67+
68+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
69+
.withUserConfiguration(JdbcChatMemoryRepositoryIT.TestApplication.class)
70+
.withPropertyValues(String.format("myapp.datasource.url=%s", postgresContainer.getJdbcUrl()),
71+
String.format("myapp.datasource.username=%s", postgresContainer.getUsername()),
72+
String.format("myapp.datasource.password=%s", postgresContainer.getPassword()));
73+
74+
@Test
75+
void correctChatMemoryRepositoryInstance() {
76+
this.contextRunner.run(context -> {
77+
var chatMemoryRepository = context.getBean(ChatMemoryRepository.class);
78+
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
79+
});
80+
}
81+
82+
@ParameterizedTest
83+
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" })
84+
void saveSingleMessage(String content, MessageType messageType) {
85+
this.contextRunner.run(context -> {
86+
var chatMemoryRepository = context.getBean(ChatMemoryRepository.class);
87+
var conversationId = UUID.randomUUID().toString();
88+
var message = switch (messageType) {
89+
case ASSISTANT -> new AssistantMessage(content + " - " + conversationId);
90+
case USER -> new UserMessage(content + " - " + conversationId);
91+
case SYSTEM -> new SystemMessage(content + " - " + conversationId);
92+
default -> throw new IllegalArgumentException("Type not supported: " + messageType);
93+
};
94+
95+
chatMemoryRepository.save(conversationId, List.of(message));
96+
97+
var jdbcTemplate = context.getBean(JdbcTemplate.class);
98+
var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?";
99+
var result = jdbcTemplate.queryForMap(query, conversationId);
100+
101+
assertThat(result.size()).isEqualTo(4);
102+
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
103+
assertThat(result.get("content")).isEqualTo(message.getText());
104+
assertThat(result.get("type")).isEqualTo(messageType.name());
105+
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
106+
});
107+
}
108+
109+
@Test
110+
void saveMultipleMessages() {
111+
this.contextRunner.run(context -> {
112+
var chatMemoryRepository = context.getBean(ChatMemoryRepository.class);
113+
var conversationId = UUID.randomUUID().toString();
114+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
115+
new UserMessage("Message from user - " + conversationId),
116+
new SystemMessage("Message from system - " + conversationId));
117+
118+
chatMemoryRepository.save(conversationId, messages);
119+
120+
var jdbcTemplate = context.getBean(JdbcTemplate.class);
121+
var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?";
122+
var results = jdbcTemplate.queryForList(query, conversationId);
123+
124+
assertThat(results.size()).isEqualTo(messages.size());
125+
126+
for (var i = 0; i < messages.size(); i++) {
127+
var message = messages.get(i);
128+
var result = results.get(i);
129+
130+
assertThat(result.get("conversation_id")).isNotNull();
131+
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
132+
assertThat(result.get("content")).isEqualTo(message.getText());
133+
assertThat(result.get("type")).isEqualTo(message.getMessageType().name());
134+
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
135+
}
136+
});
137+
}
138+
139+
@Test
140+
void findMessagesByConversationId() {
141+
this.contextRunner.run(context -> {
142+
var chatMemoryRepository = context.getBean(ChatMemoryRepository.class);
143+
var conversationId = UUID.randomUUID().toString();
144+
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
145+
new AssistantMessage("Message from assistant 2 - " + conversationId),
146+
new UserMessage("Message from user - " + conversationId),
147+
new SystemMessage("Message from system - " + conversationId));
148+
149+
chatMemoryRepository.save(conversationId, messages);
150+
151+
var results = chatMemoryRepository.findById(conversationId);
152+
153+
assertThat(results.size()).isEqualTo(messages.size());
154+
assertThat(results).isEqualTo(messages);
155+
});
156+
}
157+
158+
@Test
159+
void deleteMessagesByConversationId() {
160+
this.contextRunner.run(context -> {
161+
var chatMemoryRepository = context.getBean(ChatMemoryRepository.class);
162+
var conversationId = UUID.randomUUID().toString();
163+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + conversationId),
164+
new UserMessage("Message from user - " + conversationId),
165+
new SystemMessage("Message from system - " + conversationId));
166+
167+
chatMemoryRepository.save(conversationId, messages);
168+
169+
chatMemoryRepository.deleteById(conversationId);
170+
171+
var jdbcTemplate = context.getBean(JdbcTemplate.class);
172+
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?",
173+
Integer.class, conversationId);
174+
175+
assertThat(count).isZero();
176+
});
177+
}
178+
179+
@SpringBootConfiguration
180+
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
181+
static class TestApplication {
182+
183+
@Bean
184+
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) {
185+
JdbcChatMemoryConfig config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build();
186+
return JdbcChatMemoryRepository.create(config);
187+
}
188+
189+
@Bean
190+
JdbcTemplate jdbcTemplate(DataSource dataSource) {
191+
return new JdbcTemplate(dataSource);
192+
}
193+
194+
@Bean
195+
@Primary
196+
@ConfigurationProperties("myapp.datasource")
197+
DataSourceProperties dataSourceProperties() {
198+
return new DataSourceProperties();
199+
}
200+
201+
@Bean
202+
public DataSource dataSource(DataSourceProperties dataSourceProperties) {
203+
return dataSourceProperties.initializeDataSourceBuilder().build();
204+
}
205+
206+
}
207+
208+
}

0 commit comments

Comments
 (0)