Skip to content

Commit 5099310

Browse files
committed
feat: implement refresh logic for each database type
Signed-off-by: potato <[email protected]>
1 parent a49f14c commit 5099310

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

memory/repository/spring-ai-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/chat/memory/repository/cassandra/CassandraChatMemoryRepository.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,20 @@ public void deleteByConversationId(String conversationId) {
165165
saveAll(conversationId, List.of());
166166
}
167167

168+
@Override
169+
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
170+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
171+
Assert.notNull(deletes, "deletes cannot be null");
172+
Assert.notNull(adds, "adds cannot be null");
173+
174+
// RMW (Read-Modify-Write) is the only way with the current schema.
175+
// This is not efficient, but it is correct.
176+
List<Message> currentMessages = new ArrayList<>(this.findByConversationId(conversationId));
177+
currentMessages.removeAll(deletes);
178+
currentMessages.addAll(adds);
179+
this.saveAll(conversationId, currentMessages);
180+
}
181+
168182
private PreparedStatement prepareAddStmt() {
169183
RegularInsert stmt = null;
170184
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,51 @@ public void deleteByConversationId(String conversationId) {
106106
this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
107107
}
108108

109+
@Override
110+
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
111+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
112+
Assert.notNull(deletes, "deletes cannot be null");
113+
Assert.notNull(adds, "adds cannot be null");
114+
115+
this.transactionTemplate.execute(status -> {
116+
if (!deletes.isEmpty()) {
117+
// This is a simplification. In a real implementation, we would need a
118+
// stable
119+
// way to identify messages to delete, perhaps by adding a message_id
120+
// column.
121+
// For now, we delete based on content and type, which is not robust.
122+
this.jdbcTemplate.batchUpdate(this.dialect.getDeleteMessageSql(),
123+
new DeleteBatchPreparedStatement(conversationId, deletes));
124+
}
125+
if (!adds.isEmpty()) {
126+
this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
127+
new AddBatchPreparedStatement(conversationId, adds));
128+
}
129+
return null;
130+
});
131+
}
132+
109133
public static Builder builder() {
110134
return new Builder();
111135
}
112136

137+
private record DeleteBatchPreparedStatement(String conversationId,
138+
List<Message> messages) implements BatchPreparedStatementSetter {
139+
140+
@Override
141+
public void setValues(PreparedStatement ps, int i) throws SQLException {
142+
var message = this.messages.get(i);
143+
ps.setString(1, this.conversationId);
144+
ps.setString(2, message.getText());
145+
ps.setString(3, message.getMessageType().name());
146+
}
147+
148+
@Override
149+
public int getBatchSize() {
150+
return this.messages.size();
151+
}
152+
}
153+
113154
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
114155
AtomicLong instantSeq) implements BatchPreparedStatementSetter {
115156

memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/main/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepository.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,34 @@ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
166166
}
167167
}
168168

169+
@Override
170+
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
171+
try (Session s = this.config.getDriver().session()) {
172+
s.executeWriteWithoutResult(tx -> {
173+
if (!deletes.isEmpty()) {
174+
List<String> messageIds = deletes.stream().map(m -> (String) m.getMetadata().get("id")).toList();
175+
176+
String deleteStatement = """
177+
MATCH (m:%s) WHERE m.id IN $messageIds
178+
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
179+
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s)
180+
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s)
181+
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
182+
DETACH DELETE m, metadata, media, tr, tc
183+
""".formatted(this.config.getMessageLabel(), this.config.getMetadataLabel(),
184+
this.config.getMediaLabel(), this.config.getToolResponseLabel(),
185+
this.config.getToolCallLabel());
186+
tx.run(deleteStatement, Map.of("messageIds", messageIds));
187+
}
188+
if (!adds.isEmpty()) {
189+
for (Message m : adds) {
190+
addMessageToTransaction(tx, conversationId, m);
191+
}
192+
}
193+
});
194+
}
195+
}
196+
169197
public Neo4jChatMemoryRepositoryConfig getConfig() {
170198
return this.config;
171199
}

0 commit comments

Comments
 (0)