Skip to content

Commit af0c80d

Browse files
Optimise CassandraChatMemoryRepository for MessageWindowChatMemory usage pattern
Time-series each chat window in Cassandra, keeping past (and deleted) windows still in the db. Add ability to store different MessageTypes. Signed-off-by: mck <[email protected]>
1 parent fe4284b commit af0c80d

File tree

9 files changed

+194
-172
lines changed

9 files changed

+194
-172
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ public CassandraChatMemoryRepository chatMemory(CassandraChatMemoryProperties pr
4848

4949
builder = builder.withKeyspaceName(properties.getKeyspace())
5050
.withTableName(properties.getTable())
51-
.withAssistantColumnName(properties.getAssistantColumn())
52-
.withUserColumnName(properties.getUserColumn());
51+
.withMessagesColumnName(properties.getMessagesColumn());
5352

5453
if (!properties.isInitializeSchema()) {
5554
builder = builder.disallowSchemaChanges();

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryProperties.java

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,13 @@
3535
@ConfigurationProperties(CassandraChatMemoryProperties.CONFIG_PREFIX)
3636
public class CassandraChatMemoryProperties {
3737

38-
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.repository.cassandra";
39-
40-
private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryProperties.class);
38+
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.cassandra";
4139

4240
private String keyspace = CassandraChatMemoryConfig.DEFAULT_KEYSPACE_NAME;
4341

4442
private String table = CassandraChatMemoryConfig.DEFAULT_TABLE_NAME;
4543

46-
private String assistantColumn = CassandraChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME;
47-
48-
private String userColumn = CassandraChatMemoryConfig.DEFAULT_USER_COLUMN_NAME;
44+
private String messagesColumn = CassandraChatMemoryConfig.DEFAULT_MESSAGES_COLUMN_NAME;
4945

5046
private boolean initializeSchema = true;
5147

@@ -75,20 +71,12 @@ public void setTable(String table) {
7571
this.table = table;
7672
}
7773

78-
public String getAssistantColumn() {
79-
return this.assistantColumn;
80-
}
81-
82-
public void setAssistantColumn(String assistantColumn) {
83-
this.assistantColumn = assistantColumn;
84-
}
85-
86-
public String getUserColumn() {
87-
return this.userColumn;
74+
public String getMessagesColumn() {
75+
return this.messagesColumn;
8876
}
8977

90-
public void setUserColumn(String userColumn) {
91-
this.userColumn = userColumn;
78+
public void setMessagesColumn(String messagesColumn) {
79+
this.messagesColumn = messagesColumn;
9280
}
9381

9482
@Nullable

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/repository/cassandra/autoconfigure/CassandraChatMemoryPropertiesTest.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ void defaultValues() {
3636
var props = new CassandraChatMemoryProperties();
3737
assertThat(props.getKeyspace()).isEqualTo(CassandraChatMemoryConfig.DEFAULT_KEYSPACE_NAME);
3838
assertThat(props.getTable()).isEqualTo(CassandraChatMemoryConfig.DEFAULT_TABLE_NAME);
39-
assertThat(props.getAssistantColumn()).isEqualTo(CassandraChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME);
40-
assertThat(props.getUserColumn()).isEqualTo(CassandraChatMemoryConfig.DEFAULT_USER_COLUMN_NAME);
39+
assertThat(props.getMessagesColumn()).isEqualTo(CassandraChatMemoryConfig.DEFAULT_MESSAGES_COLUMN_NAME);
4140
assertThat(props.getTimeToLive()).isNull();
4241
assertThat(props.isInitializeSchema()).isTrue();
4342
}
@@ -47,15 +46,13 @@ void customValues() {
4746
var props = new CassandraChatMemoryProperties();
4847
props.setKeyspace("my_keyspace");
4948
props.setTable("my_table");
50-
props.setAssistantColumn("my_assistant_column");
51-
props.setUserColumn("my_user_column");
49+
props.setMessagesColumn("my_messages_column");
5250
props.setTimeToLive(Duration.ofDays(1));
5351
props.setInitializeSchema(false);
5452

5553
assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
5654
assertThat(props.getTable()).isEqualTo("my_table");
57-
assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column");
58-
assertThat(props.getUserColumn()).isEqualTo("my_user_column");
55+
assertThat(props.getMessagesColumn()).isEqualTo("my_messages_column");
5956
assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1));
6057
assertThat(props.isInitializeSchema()).isFalse();
6158
}

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,17 @@ public void add(String conversationId, List<Message> messages) {
5555

5656
@Override
5757
public void add(String sessionId, Message msg) {
58-
repo.save(sessionId, msg);
58+
repo.saveAll(sessionId, List.of(msg));
5959
}
6060

6161
@Override
6262
public void clear(String sessionId) {
63-
repo.deleteByConversationId(sessionId);
63+
repo.clear(sessionId);
6464
}
6565

6666
@Override
6767
public List<Message> get(String sessionId, int lastN) {
68-
return repo.findByConversationId(sessionId).subList(0, lastN);
68+
return repo.findByConversationIdWithLimit(sessionId, lastN);
6969
}
7070

7171
}

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryConfig.java

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
3131
import com.datastax.oss.driver.api.core.type.DataType;
3232
import com.datastax.oss.driver.api.core.type.DataTypes;
33+
import com.datastax.oss.driver.api.core.type.UserDefinedType;
3334
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
3435
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
3536
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
@@ -61,19 +62,26 @@ public final class CassandraChatMemoryConfig {
6162
// todo – make configurable
6263
public static final String DEFAULT_EXCHANGE_ID_NAME = "message_timestamp";
6364

64-
public static final String DEFAULT_ASSISTANT_COLUMN_NAME = "assistant";
65-
66-
public static final String DEFAULT_USER_COLUMN_NAME = "user";
65+
public static final String DEFAULT_MESSAGES_COLUMN_NAME = "messages";
6766

6867
private static final Logger logger = LoggerFactory.getLogger(CassandraChatMemoryConfig.class);
6968

7069
final CqlSession session;
7170

7271
final Schema schema;
7372

74-
final String assistantColumn;
73+
final String messageUDT = "ai_chat_message";
74+
75+
final String messagesColumn;
76+
77+
// todo – make configurable
78+
final String messageUdtTimestampColumn = "msg_timestamp";
79+
80+
// todo – make configurable
81+
final String messageUdtTypeColumn = "msg_type";
7582

76-
final String userColumn;
83+
// todo – make configurable
84+
final String messageUdtContentColumn = "msg_content";
7785

7886
final SessionIdToPrimaryKeysTranslator primaryKeyTranslator;
7987

@@ -84,8 +92,7 @@ public final class CassandraChatMemoryConfig {
8492
private CassandraChatMemoryConfig(Builder builder) {
8593
this.session = builder.session;
8694
this.schema = new Schema(builder.keyspace, builder.table, builder.partitionKeys, builder.clusteringKeys);
87-
this.assistantColumn = builder.assistantColumn;
88-
this.userColumn = builder.userColumn;
95+
this.messagesColumn = builder.messagesColumn;
8996
this.timeToLiveSeconds = builder.timeToLiveSeconds;
9097
this.disallowSchemaChanges = builder.disallowSchemaChanges;
9198
this.primaryKeyTranslator = builder.primaryKeyTranslator;
@@ -109,6 +116,7 @@ void dropKeyspace() {
109116
void ensureSchemaExists() {
110117
if (!this.disallowSchemaChanges) {
111118
SchemaUtil.ensureKeyspaceExists(this.session, this.schema.keyspace);
119+
ensureMessageTypeExist();
112120
ensureTableExists();
113121
ensureTableColumnsExist();
114122
SchemaUtil.checkSchemaAgreement(this.session);
@@ -129,17 +137,35 @@ void checkSchemaValid() {
129137
.getTable(this.schema.table)
130138
.isPresent(), "table %s does not exist");
131139

140+
Preconditions.checkState(this.session.getMetadata()
141+
.getKeyspace(this.schema.keyspace())
142+
.get()
143+
.getUserDefinedType(messageUDT)
144+
.isPresent(), "table %s does not exist");
145+
146+
UserDefinedType udt = this.session.getMetadata()
147+
.getKeyspace(this.schema.keyspace())
148+
.get()
149+
.getUserDefinedType(messageUDT)
150+
.get();
151+
152+
Preconditions.checkState(udt.contains(this.messageUdtTimestampColumn), "field %s does not exist",
153+
this.messageUdtTimestampColumn);
154+
155+
Preconditions.checkState(udt.contains(this.messageUdtTypeColumn), "field %s does not exist",
156+
this.messageUdtTypeColumn);
157+
158+
Preconditions.checkState(udt.contains(this.messageUdtContentColumn), "field %s does not exist",
159+
this.messageUdtContentColumn);
160+
132161
TableMetadata tableMetadata = this.session.getMetadata()
133162
.getKeyspace(this.schema.keyspace)
134163
.get()
135164
.getTable(this.schema.table)
136165
.get();
137166

138-
Preconditions.checkState(tableMetadata.getColumn(this.assistantColumn).isPresent(), "column %s does not exist",
139-
this.assistantColumn);
140-
141-
Preconditions.checkState(tableMetadata.getColumn(this.userColumn).isPresent(), "column %s does not exist",
142-
this.userColumn);
167+
Preconditions.checkState(tableMetadata.getColumn(this.messagesColumn).isPresent(), "column %s does not exist",
168+
this.messagesColumn);
143169
}
144170

145171
private void ensureTableExists() {
@@ -159,9 +185,11 @@ private void ensureTableExists() {
159185

160186
String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name();
161187

162-
CreateTableWithOptions createTableWithOptions = createTable.withColumn(this.userColumn, DataTypes.TEXT)
188+
CreateTableWithOptions createTableWithOptions = createTable
189+
.withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
163190
.withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC)
164-
// TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() is available
191+
// TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() when
192+
// available
165193
.withOption("compaction", Map.of("class", "UnifiedCompactionStrategy"));
166194

167195
if (null != this.timeToLiveSeconds) {
@@ -171,6 +199,18 @@ private void ensureTableExists() {
171199
}
172200
}
173201

202+
private void ensureMessageTypeExist() {
203+
204+
SimpleStatement stmt = SchemaBuilder.createType(messageUDT)
205+
.ifNotExists()
206+
.withField(messageUdtTimestampColumn, DataTypes.TIMESTAMP)
207+
.withField(messageUdtTypeColumn, DataTypes.TEXT)
208+
.withField(messageUdtContentColumn, DataTypes.TEXT)
209+
.build();
210+
211+
this.session.execute(stmt.setKeyspace(this.schema.keyspace));
212+
}
213+
174214
private void ensureTableColumnsExist() {
175215

176216
TableMetadata tableMetadata = this.session.getMetadata()
@@ -179,18 +219,12 @@ private void ensureTableColumnsExist() {
179219
.getTable(this.schema.table())
180220
.get();
181221

182-
boolean addAssistantColumn = tableMetadata.getColumn(this.assistantColumn).isEmpty();
183-
boolean addUserColumn = tableMetadata.getColumn(this.userColumn).isEmpty();
222+
if (tableMetadata.getColumn(this.messagesColumn).isEmpty()) {
223+
224+
SimpleStatement stmt = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table())
225+
.addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
226+
.build();
184227

185-
if (addAssistantColumn || addUserColumn) {
186-
AlterTableAddColumn alterTable = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table());
187-
if (addAssistantColumn) {
188-
alterTable = alterTable.addColumn(this.assistantColumn, DataTypes.TEXT);
189-
}
190-
if (addUserColumn) {
191-
alterTable = alterTable.addColumn(this.userColumn, DataTypes.TEXT);
192-
}
193-
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
194228
logger.debug("Executing {}", stmt.getQuery());
195229
this.session.execute(stmt);
196230
}
@@ -228,9 +262,7 @@ public static final class Builder {
228262
private List<SchemaColumn> clusteringKeys = List
229263
.of(new SchemaColumn(DEFAULT_EXCHANGE_ID_NAME, DataTypes.TIMESTAMP));
230264

231-
private String assistantColumn = DEFAULT_ASSISTANT_COLUMN_NAME;
232-
233-
private String userColumn = DEFAULT_USER_COLUMN_NAME;
265+
private String messagesColumn = DEFAULT_MESSAGES_COLUMN_NAME;
234266

235267
private Integer timeToLiveSeconds = null;
236268

@@ -289,13 +321,8 @@ public Builder withClusteringKeys(List<SchemaColumn> clusteringKeys) {
289321
return this;
290322
}
291323

292-
public Builder withAssistantColumnName(String name) {
293-
this.assistantColumn = name;
294-
return this;
295-
}
296-
297-
public Builder withUserColumnName(String name) {
298-
this.userColumn = name;
324+
public Builder withMessagesColumnName(String name) {
325+
this.messagesColumn = name;
299326
return this;
300327
}
301328

0 commit comments

Comments
 (0)