Skip to content

Commit f7ea0bd

Browse files
committed
feat: Support customizing the table name and related column names used for storing chat history when constructing the JdbcChatMemoryRepository.
Signed-off-by: Sun Yuhan <[email protected]>
1 parent 0569df7 commit f7ea0bd

File tree

2 files changed

+78
-17
lines changed

2 files changed

+78
-17
lines changed

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,51 @@
4848
*/
4949
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
5050

51-
private static final String QUERY_GET_IDS = """
52-
SELECT DISTINCT conversation_id FROM ai_chat_memory
53-
""";
51+
private final String queryGetIds;
5452

55-
private static final String QUERY_ADD = """
56-
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)
57-
""";
53+
private final String queryAdd;
5854

59-
private static final String QUERY_GET = """
60-
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp"
61-
""";
55+
private final String queryGet;
6256

63-
private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";
57+
private final String queryClear;
6458

6559
private final JdbcTemplate jdbcTemplate;
6660

67-
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) {
61+
public final static String DEFAULT_TABLE_NAME = "ai_chat_memory";
62+
public final static String DEFAULT_CONVERSION_ID_FIELD_NAME = "conversation_id";
63+
public final static String DEFAULT_CONTENT_FIELD_NAME = "content";
64+
public final static String DEFAULT_TYPE_FIELD_NAME = "type";
65+
public final static String DEFAULT_TIMESTAMP_FIELD_NAME = "\"timestamp\"";
66+
67+
public final static String DEFAULT_GET_IDS_QUERY = "SELECT DISTINCT %s FROM %s";
68+
public final static String DEFAULT_ADD_QUERY = "INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)";
69+
public final static String DEFAULT_GET_QUERY = "SELECT %s, %s FROM %s WHERE %s = ? ORDER BY %s";
70+
public final static String DEFAULT_CLEAR_QUERY = "DELETE FROM %s WHERE %s = ?";
71+
72+
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate,
73+
String tableName,
74+
String conversionIdFiledName,
75+
String contentFiledName,
76+
String typeFiledName,
77+
String timestampFiledName) {
6878
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
79+
Assert.notNull(tableName, "tableName cannot be null");
80+
Assert.notNull(conversionIdFiledName, "conversionIdFiledName cannot be null");
81+
Assert.notNull(contentFiledName, "contentFiledName cannot be null");
82+
Assert.notNull(typeFiledName, "typeFiledName cannot be null");
83+
Assert.notNull(timestampFiledName, "timestampFiledName cannot be null");
6984
this.jdbcTemplate = jdbcTemplate;
85+
this.queryGetIds = DEFAULT_GET_IDS_QUERY.formatted(conversionIdFiledName, tableName);
86+
this.queryAdd = DEFAULT_ADD_QUERY.formatted(
87+
tableName, conversionIdFiledName, contentFiledName, typeFiledName, timestampFiledName);
88+
this.queryGet = DEFAULT_GET_QUERY.formatted(contentFiledName, typeFiledName,
89+
tableName, conversionIdFiledName, timestampFiledName);
90+
this.queryClear = DEFAULT_CLEAR_QUERY.formatted(tableName, conversionIdFiledName);
7091
}
7192

7293
@Override
7394
public List<String> findConversationIds() {
74-
List<String> conversationIds = this.jdbcTemplate.query(QUERY_GET_IDS, rs -> {
95+
List<String> conversationIds = this.jdbcTemplate.query(queryGetIds, rs -> {
7596
var ids = new ArrayList<String>();
7697
while (rs.next()) {
7798
ids.add(rs.getString(1));
@@ -84,7 +105,7 @@ public List<String> findConversationIds() {
84105
@Override
85106
public List<Message> findByConversationId(String conversationId) {
86107
Assert.hasText(conversationId, "conversationId cannot be null or empty");
87-
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId);
108+
return this.jdbcTemplate.query(queryGet, new MessageRowMapper(), conversationId);
88109
}
89110

90111
@Override
@@ -93,13 +114,13 @@ public void saveAll(String conversationId, List<Message> messages) {
93114
Assert.notNull(messages, "messages cannot be null");
94115
Assert.noNullElements(messages, "messages cannot contain null elements");
95116
this.deleteByConversationId(conversationId);
96-
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
117+
this.jdbcTemplate.batchUpdate(queryAdd, new AddBatchPreparedStatement(conversationId, messages));
97118
}
98119

99120
@Override
100121
public void deleteByConversationId(String conversationId) {
101122
Assert.hasText(conversationId, "conversationId cannot be null or empty");
102-
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
123+
this.jdbcTemplate.update(queryClear, conversationId);
103124
}
104125

105126
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
@@ -154,6 +175,12 @@ public static class Builder {
154175

155176
private JdbcTemplate jdbcTemplate;
156177

178+
private String tableName = DEFAULT_TABLE_NAME;
179+
private String conversionIdFiledName = DEFAULT_CONVERSION_ID_FIELD_NAME;
180+
private String contentFiledName = DEFAULT_CONTENT_FIELD_NAME;
181+
private String typeFiledName = DEFAULT_TYPE_FIELD_NAME;
182+
private String timestampFiledName = DEFAULT_TIMESTAMP_FIELD_NAME;
183+
157184
private Builder() {
158185
}
159186

@@ -162,8 +189,34 @@ public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
162189
return this;
163190
}
164191

192+
public Builder tableName(String tableName) {
193+
this.tableName = tableName;
194+
return this;
195+
}
196+
197+
public Builder conversionIdFiledName(String conversionIdFiledName) {
198+
this.conversionIdFiledName = conversionIdFiledName;
199+
return this;
200+
}
201+
202+
public Builder contentFiledName(String contentFiledName) {
203+
this.contentFiledName = contentFiledName;
204+
return this;
205+
}
206+
207+
public Builder typeFiledName(String typeFiledName) {
208+
this.typeFiledName = typeFiledName;
209+
return this;
210+
}
211+
212+
public Builder timestampFiledName(String timestampFiledName) {
213+
this.timestampFiledName = timestampFiledName;
214+
return this;
215+
}
216+
165217
public JdbcChatMemoryRepository build() {
166-
return new JdbcChatMemoryRepository(this.jdbcTemplate);
218+
return new JdbcChatMemoryRepository(this.jdbcTemplate, this.tableName, this.conversionIdFiledName,
219+
this.contentFiledName, this.typeFiledName, this.timestampFiledName);
167220
}
168221

169222
}

memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import java.util.UUID;
4242

4343
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository.*;
4445

4546
/**
4647
* Integration tests for {@link JdbcChatMemoryRepository}.
@@ -160,7 +161,14 @@ static class TestConfiguration {
160161

161162
@Bean
162163
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) {
163-
return JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build();
164+
return JdbcChatMemoryRepository.builder()
165+
.jdbcTemplate(jdbcTemplate)
166+
.tableName(DEFAULT_TABLE_NAME)
167+
.conversionIdFiledName(DEFAULT_CONVERSION_ID_FIELD_NAME)
168+
.contentFiledName(DEFAULT_CONTENT_FIELD_NAME)
169+
.typeFiledName(DEFAULT_TYPE_FIELD_NAME)
170+
.timestampFiledName(DEFAULT_TIMESTAMP_FIELD_NAME)
171+
.build();
164172
}
165173

166174
}

0 commit comments

Comments
 (0)