Skip to content

Commit a5649da

Browse files
Optimise CassandraChatMemoryRepository for MessageWindowChatMemory usage pattern
Time-series each chat window in Cassandra, keeping past (and deleted) windows still in the db. Add ability to store different MessageTypes. Signed-off-by: mck <[email protected]>
1 parent 30eb3ce commit a5649da

File tree

8 files changed

+185
-189
lines changed

8 files changed

+185
-189
lines changed

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfiguration.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ public CassandraChatMemoryRepository cassandraChatMemoryRepository(
4949

5050
builder = builder.withKeyspaceName(properties.getKeyspace())
5151
.withTableName(properties.getTable())
52-
.withAssistantColumnName(properties.getAssistantColumn())
53-
.withUserColumnName(properties.getUserColumn());
52+
.withMessagesColumnName(properties.getMessagesColumn());
5453

5554
if (!properties.isInitializeSchema()) {
5655
builder = builder.disallowSchemaChanges();

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryProperties.java

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818

1919
import java.time.Duration;
2020

21-
import org.slf4j.Logger;
22-
import org.slf4j.LoggerFactory;
23-
2421
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepositoryConfig;
2522
import org.springframework.boot.context.properties.ConfigurationProperties;
2623
import org.springframework.lang.Nullable;
@@ -35,17 +32,13 @@
3532
@ConfigurationProperties(CassandraChatMemoryRepositoryProperties.CONFIG_PREFIX)
3633
public class CassandraChatMemoryRepositoryProperties {
3734

38-
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra";
39-
40-
private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryRepositoryProperties.class);
35+
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra";
4136

4237
private String keyspace = CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME;
4338

4439
private String table = CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME;
4540

46-
private String assistantColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME;
47-
48-
private String userColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME;
41+
private String messagesColumn = CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME;
4942

5043
private boolean initializeSchema = true;
5144

@@ -75,20 +68,12 @@ public void setTable(String table) {
7568
this.table = table;
7669
}
7770

78-
public String getAssistantColumn() {
79-
return this.assistantColumn;
80-
}
81-
82-
public void setAssistantColumn(String assistantColumn) {
83-
this.assistantColumn = assistantColumn;
84-
}
85-
86-
public String getUserColumn() {
87-
return this.userColumn;
71+
public String getMessagesColumn() {
72+
return this.messagesColumn;
8873
}
8974

90-
public void setUserColumn(String userColumn) {
91-
this.userColumn = userColumn;
75+
public void setMessagesColumn(String messagesColumn) {
76+
this.messagesColumn = messagesColumn;
9277
}
9378

9479
@Nullable

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryAutoConfigurationIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void addAndGet() {
5959
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
6060
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
6161
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
62-
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLive())
62+
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive())
6363
.run(context -> {
6464
CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class);
6565

@@ -96,7 +96,7 @@ void compareTimeToLive_ISO8601Format() {
9696
this.contextRunner.withPropertyValues("spring.cassandra.contactPoints=" + getContactPointHost())
9797
.withPropertyValues("spring.cassandra.port=" + getContactPointPort())
9898
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
99-
.withPropertyValues("spring.ai.chat.memory.repository.cassandra.time-to-live=" + getTimeToLiveString())
99+
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLiveString())
100100
.run(context -> {
101101
CassandraChatMemoryRepositoryProperties properties = context
102102
.getBean(CassandraChatMemoryRepositoryProperties.class);

auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryRepositoryPropertiesTest.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ void defaultValues() {
3636
var props = new CassandraChatMemoryRepositoryProperties();
3737
assertThat(props.getKeyspace()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_KEYSPACE_NAME);
3838
assertThat(props.getTable()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_TABLE_NAME);
39-
assertThat(props.getAssistantColumn())
40-
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME);
41-
assertThat(props.getUserColumn()).isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_USER_COLUMN_NAME);
39+
assertThat(props.getMessagesColumn())
40+
.isEqualTo(CassandraChatMemoryRepositoryConfig.DEFAULT_MESSAGES_COLUMN_NAME);
41+
4242
assertThat(props.getTimeToLive()).isNull();
4343
assertThat(props.isInitializeSchema()).isTrue();
4444
}
@@ -48,15 +48,13 @@ void customValues() {
4848
var props = new CassandraChatMemoryRepositoryProperties();
4949
props.setKeyspace("my_keyspace");
5050
props.setTable("my_table");
51-
props.setAssistantColumn("my_assistant_column");
52-
props.setUserColumn("my_user_column");
51+
props.setMessagesColumn("my_messages_column");
5352
props.setTimeToLive(Duration.ofDays(1));
5453
props.setInitializeSchema(false);
5554

5655
assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
5756
assertThat(props.getTable()).isEqualTo("my_table");
58-
assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column");
59-
assertThat(props.getUserColumn()).isEqualTo("my_user_column");
57+
assertThat(props.getMessagesColumn()).isEqualTo("my_messages_column");
6058
assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1));
6159
assertThat(props.isInitializeSchema()).isFalse();
6260
}

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java

Lines changed: 61 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,24 @@
2626
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
2727
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
2828
import com.datastax.oss.driver.api.core.cql.Row;
29+
import com.datastax.oss.driver.api.core.data.UdtValue;
30+
import com.datastax.oss.driver.api.core.type.UserDefinedType;
2931
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
32+
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
3033
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
3134
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
3235
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
3336
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
3437
import com.datastax.oss.driver.api.querybuilder.select.Select;
3538
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
39+
import java.util.Map;
3640
import org.springframework.ai.chat.memory.ChatMemoryRepository;
3741
import org.springframework.ai.chat.messages.AssistantMessage;
3842
import org.springframework.ai.chat.messages.Message;
43+
import org.springframework.ai.chat.messages.MessageType;
44+
import static org.springframework.ai.chat.messages.MessageType.ASSISTANT;
45+
import org.springframework.ai.chat.messages.SystemMessage;
46+
import org.springframework.ai.chat.messages.ToolResponseMessage;
3947
import org.springframework.ai.chat.messages.UserMessage;
4048
import org.springframework.util.Assert;
4149

@@ -54,23 +62,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
5462

5563
private final PreparedStatement allStmt;
5664

57-
private final PreparedStatement addUserStmt;
58-
59-
private final PreparedStatement addAssistantStmt;
65+
private final PreparedStatement addStmt;
6066

6167
private final PreparedStatement getStmt;
6268

63-
private final PreparedStatement deleteStmt;
64-
6569
private CassandraChatMemoryRepository(CassandraChatMemoryRepositoryConfig conf) {
6670
Assert.notNull(conf, "conf cannot be null");
6771
this.conf = conf;
6872
this.conf.ensureSchemaExists();
6973
this.allStmt = prepareAllStatement();
70-
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
71-
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
74+
this.addStmt = prepareAddStmt();
7275
this.getStmt = prepareGetStatement();
73-
this.deleteStmt = prepareDeleteStmt();
7476
}
7577

7678
public static CassandraChatMemoryRepository create(CassandraChatMemoryRepositoryConfig conf) {
@@ -97,6 +99,10 @@ public List<String> findConversationIds() {
9799

98100
@Override
99101
public List<Message> findByConversationId(String conversationId) {
102+
return findByConversationIdWithLimit(conversationId, 1);
103+
}
104+
105+
List<Message> findByConversationIdWithLimit(String conversationId, int limit) {
100106
Assert.hasText(conversationId, "conversationId cannot be null or empty");
101107

102108
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
@@ -106,19 +112,14 @@ public List<Message> findByConversationId(String conversationId) {
106112
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
107113
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
108114
}
115+
builder = builder.setInt("legacy_limit", limit);
109116

110117
List<Message> messages = new ArrayList<>();
111118
for (Row r : this.conf.session.execute(builder.build())) {
112-
String assistant = r.getString(this.conf.assistantColumn);
113-
String user = r.getString(this.conf.userColumn);
114-
if (null != assistant) {
115-
messages.add(new AssistantMessage(assistant));
116-
}
117-
if (null != user) {
118-
messages.add(new UserMessage(user));
119+
for (UdtValue udt : r.getList(this.conf.messagesColumn, UdtValue.class)) {
120+
messages.add(getMessage(udt));
119121
}
120122
}
121-
Collections.reverse(messages);
122123
return messages;
123124
}
124125

@@ -128,58 +129,49 @@ public void saveAll(String conversationId, List<Message> messages) {
128129
Assert.notNull(messages, "messages cannot be null");
129130
Assert.noNullElements(messages, "messages cannot contain null elements");
130131

131-
final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
132-
messages.forEach(msg -> {
133-
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
134-
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
135-
}
136-
save(conversationId, msg);
137-
});
138-
}
139-
140-
void save(String conversationId, Message msg) {
141-
142-
Preconditions.checkArgument(
143-
!msg.getMetadata().containsKey(CONVERSATION_TS)
144-
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
145-
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
146-
147-
msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());
148-
149-
PreparedStatement stmt = getStatement(msg);
150-
132+
Instant instant = Instant.now();
151133
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
152-
BoundStatementBuilder builder = stmt.boundStatementBuilder();
134+
BoundStatementBuilder builder = addStmt.boundStatementBuilder();
153135

154136
for (int k = 0; k < primaryKeys.size(); ++k) {
155137
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
156138
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
157139
}
158140

159-
Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);
141+
List<UdtValue> msgs = new ArrayList<>();
142+
for (Message msg : messages) {
143+
144+
Preconditions.checkArgument(
145+
!msg.getMetadata().containsKey(CONVERSATION_TS)
146+
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
147+
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);
160148

149+
msg.getMetadata().putIfAbsent(CONVERSATION_TS, instant);
150+
151+
UdtValue udt = this.conf.session.getMetadata()
152+
.getKeyspace(this.conf.schema.keyspace())
153+
.get()
154+
.getUserDefinedType(this.conf.messageUDT)
155+
.get()
156+
.newValue()
157+
.setInstant(this.conf.messageUdtTimestampColumn, (Instant) msg.getMetadata().get(CONVERSATION_TS))
158+
.setString(this.conf.messageUdtTypeColumn, msg.getMessageType().name())
159+
.setString(this.conf.messageUdtContentColumn, msg.getText());
160+
161+
msgs.add(udt);
162+
}
161163
builder = builder.setInstant(CassandraChatMemoryRepositoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
162-
.setString("message", msg.getText());
164+
.setList("msgs", msgs, UdtValue.class);
163165

164166
this.conf.session.execute(builder.build());
165167
}
166168

167169
@Override
168170
public void deleteByConversationId(String conversationId) {
169-
Assert.hasText(conversationId, "conversationId cannot be null or empty");
170-
171-
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
172-
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();
173-
174-
for (int k = 0; k < primaryKeys.size(); ++k) {
175-
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
176-
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
177-
}
178-
179-
this.conf.session.execute(builder.build());
171+
saveAll(conversationId, List.of());
180172
}
181173

182-
private PreparedStatement prepareAddStmt(String column) {
174+
private PreparedStatement prepareAddStmt() {
183175
RegularInsert stmt = null;
184176
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
185177
for (var c : this.conf.schema.partitionKeys()) {
@@ -188,7 +180,7 @@ private PreparedStatement prepareAddStmt(String column) {
188180
for (var c : this.conf.schema.clusteringKeys()) {
189181
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
190182
}
191-
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
183+
stmt = stmt.value(this.conf.messagesColumn, QueryBuilder.bindMarker("msgs"));
192184
return this.conf.session.prepare(stmt.build());
193185
}
194186

@@ -214,28 +206,27 @@ private PreparedStatement prepareGetStatement() {
214206
String columnName = this.conf.schema.clusteringKeys().get(i).name();
215207
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
216208
}
209+
stmt = stmt.limit(QueryBuilder.bindMarker("legacy_limit"));
217210
return this.conf.session.prepare(stmt.build());
218211
}
219212

220-
private PreparedStatement prepareDeleteStmt() {
221-
Delete stmt = null;
222-
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
223-
for (var c : this.conf.schema.partitionKeys()) {
224-
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
225-
}
226-
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
227-
String columnName = this.conf.schema.clusteringKeys().get(i).name();
228-
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
213+
private Message getMessage(UdtValue udt) {
214+
String content = udt.getString(this.conf.messageUdtContentColumn);
215+
Map<String, Object> props = Map.of(CONVERSATION_TS, udt.getInstant(this.conf.messageUdtTimestampColumn));
216+
switch (MessageType.valueOf(udt.getString(this.conf.messageUdtTypeColumn))) {
217+
case ASSISTANT:
218+
return new AssistantMessage(content, props);
219+
case USER:
220+
return UserMessage.builder().text(content).metadata(props).build();
221+
case SYSTEM:
222+
return SystemMessage.builder().text(content).metadata(props).build();
223+
case TOOL:
224+
// todo – persist ToolResponse somehow
225+
return new ToolResponseMessage(List.of(), props);
226+
default:
227+
throw new IllegalStateException(
228+
String.format("unknown message type %s", udt.getString(this.conf.messageUdtTypeColumn)));
229229
}
230-
return this.conf.session.prepare(stmt.build());
231-
}
232-
233-
private PreparedStatement getStatement(Message msg) {
234-
return switch (msg.getMessageType()) {
235-
case USER -> this.addUserStmt;
236-
case ASSISTANT -> this.addAssistantStmt;
237-
default -> throw new IllegalArgumentException("Cant add type " + msg);
238-
};
239230
}
240231

241232
}

0 commit comments

Comments
 (0)