Skip to content

Commit 58ac510

Browse files
Implement CassandraChatMemoryRepository
ref: #2998
1 parent 10ff11d commit 58ac510

File tree

1 file changed

+245
-0
lines changed

1 file changed

+245
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.memory.cassandra;
18+
19+
import java.time.Instant;
20+
import java.util.ArrayList;
21+
import java.util.Collections;
22+
import java.util.List;
23+
import java.util.concurrent.atomic.AtomicLong;
24+
25+
import com.datastax.oss.driver.api.core.cql.BoundStatement;
26+
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
27+
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
28+
import com.datastax.oss.driver.api.core.cql.Row;
29+
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
30+
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
31+
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
32+
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
33+
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
34+
import com.datastax.oss.driver.api.querybuilder.select.Select;
35+
import com.datastax.oss.driver.api.querybuilder.select.Selector;
36+
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
37+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
38+
import org.springframework.ai.chat.messages.AssistantMessage;
39+
import org.springframework.ai.chat.messages.Message;
40+
import org.springframework.ai.chat.messages.UserMessage;
41+
import org.springframework.util.Assert;
42+
43+
import static org.springframework.ai.chat.memory.cassandra.CassandraChatMemory.CONVERSATION_TS;
44+
import static org.springframework.ai.chat.messages.MessageType.ASSISTANT;
45+
import static org.springframework.ai.chat.messages.MessageType.USER;
46+
47+
/**
48+
* An implementation of {@link ChatMemoryRepository} for Apache Cassandra.
49+
*
50+
* @author Mick Semb Wever
51+
* @since 1.0.0
52+
*/
53+
public class CassandraChatMemoryRepository implements ChatMemoryRepository {
54+
55+
private final CassandraChatMemoryConfig conf;
56+
57+
private final PreparedStatement allStmt;
58+
59+
private final PreparedStatement addUserStmt;
60+
61+
private final PreparedStatement addAssistantStmt;
62+
63+
private final PreparedStatement getStmt;
64+
65+
private final PreparedStatement deleteStmt;
66+
67+
private CassandraChatMemoryRepository(CassandraChatMemoryConfig conf) {
68+
Assert.notNull(conf, "conf cannot be null");
69+
this.conf = conf;
70+
this.conf.ensureSchemaExists();
71+
this.allStmt = prepareAllStatement();
72+
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
73+
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
74+
this.getStmt = prepareGetStatement();
75+
this.deleteStmt = prepareDeleteStmt();
76+
}
77+
78+
public static CassandraChatMemoryRepository create(CassandraChatMemoryConfig conf) {
79+
return new CassandraChatMemoryRepository(conf);
80+
}
81+
82+
@Override
83+
public List<String> findConversationIds() {
84+
List<String> conversationIds = new ArrayList<>();
85+
long token = Long.MIN_VALUE;
86+
boolean emptyQuery = false;
87+
88+
while (emptyQuery || token < Long.MAX_VALUE) {
89+
BoundStatement stmt = this.allStmt.boundStatementBuilder().setLong("token", token).build();
90+
emptyQuery = true;
91+
for (Row r : this.conf.session.execute(stmt)) {
92+
emptyQuery = false;
93+
conversationIds.add(r.getString(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME));
94+
token = r.getLong("t");
95+
}
96+
}
97+
return List.copyOf(conversationIds);
98+
}
99+
100+
@Override
101+
public List<Message> findByConversationId(String conversationId) {
102+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
103+
104+
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
105+
BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", Integer.MAX_VALUE);
106+
107+
for (int k = 0; k < primaryKeys.size(); ++k) {
108+
CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
109+
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
110+
}
111+
112+
List<Message> messages = new ArrayList<>();
113+
for (Row r : this.conf.session.execute(builder.build())) {
114+
String assistant = r.getString(this.conf.assistantColumn);
115+
String user = r.getString(this.conf.userColumn);
116+
if (null != assistant) {
117+
messages.add(new AssistantMessage(assistant));
118+
}
119+
if (null != user) {
120+
messages.add(new UserMessage(user));
121+
}
122+
}
123+
Collections.reverse(messages);
124+
return messages;
125+
}
126+
127+
@Override
128+
public void saveAll(String conversationId, List<Message> messages) {
129+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
130+
Assert.notNull(messages, "messages cannot be null");
131+
Assert.noNullElements(messages, "messages cannot contain null elements");
132+
133+
final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
134+
messages.forEach(msg -> {
135+
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
136+
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
137+
}
138+
save(conversationId, msg);
139+
});
140+
}
141+
142+
private void save(String conversationId, Message msg) {
143+
144+
Preconditions.checkArgument(
145+
!msg.getMetadata().containsKey(CONVERSATION_TS)
146+
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
147+
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
148+
149+
msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());
150+
151+
PreparedStatement stmt = getStatement(msg);
152+
153+
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
154+
BoundStatementBuilder builder = stmt.boundStatementBuilder();
155+
156+
for (int k = 0; k < primaryKeys.size(); ++k) {
157+
CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
158+
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
159+
}
160+
161+
Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);
162+
163+
builder = builder.setInstant(CassandraChatMemoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
164+
.setString("message", msg.getText());
165+
166+
this.conf.session.execute(builder.build());
167+
}
168+
169+
@Override
170+
public void deleteByConversationId(String conversationId) {
171+
Assert.hasText(conversationId, "conversationId cannot be null or empty");
172+
173+
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
174+
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();
175+
176+
for (int k = 0; k < primaryKeys.size(); ++k) {
177+
CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
178+
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
179+
}
180+
181+
this.conf.session.execute(builder.build());
182+
}
183+
184+
private PreparedStatement prepareAddStmt(String column) {
185+
RegularInsert stmt = null;
186+
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
187+
for (var c : this.conf.schema.partitionKeys()) {
188+
stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name()));
189+
}
190+
for (var c : this.conf.schema.clusteringKeys()) {
191+
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
192+
}
193+
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
194+
return this.conf.session.prepare(stmt.build());
195+
}
196+
197+
private PreparedStatement prepareAllStatement() {
198+
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table())
199+
.distinct()
200+
.function("token", Selector.column(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME))
201+
.as("t")
202+
.column(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME)
203+
.whereToken(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME)
204+
.isGreaterThan(QueryBuilder.bindMarker("token"))
205+
.limit(1000)
206+
.allowFiltering();
207+
208+
return this.conf.session.prepare(stmt.build());
209+
}
210+
211+
private PreparedStatement prepareGetStatement() {
212+
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()).all();
213+
for (var c : this.conf.schema.partitionKeys()) {
214+
stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
215+
}
216+
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
217+
String columnName = this.conf.schema.clusteringKeys().get(i).name();
218+
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
219+
}
220+
stmt = stmt.limit(QueryBuilder.bindMarker("lastN"));
221+
return this.conf.session.prepare(stmt.build());
222+
}
223+
224+
private PreparedStatement prepareDeleteStmt() {
225+
Delete stmt = null;
226+
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
227+
for (var c : this.conf.schema.partitionKeys()) {
228+
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
229+
}
230+
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
231+
String columnName = this.conf.schema.clusteringKeys().get(i).name();
232+
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
233+
}
234+
return this.conf.session.prepare(stmt.build());
235+
}
236+
237+
private PreparedStatement getStatement(Message msg) {
238+
return switch (msg.getMessageType()) {
239+
case USER -> this.addUserStmt;
240+
case ASSISTANT -> this.addAssistantStmt;
241+
default -> throw new IllegalArgumentException("Cant add type " + msg);
242+
};
243+
}
244+
245+
}

0 commit comments

Comments
 (0)