From a9ed8ff7924c68585309df9bae411bf1fbe249da Mon Sep 17 00:00:00 2001 From: mck Date: Wed, 7 May 2025 23:14:50 +0200 Subject: [PATCH] Implement CassandraChatMemoryRepository ref: https://github.com/spring-projects/spring-ai/issues/2998 Signed-off-by: mck --- .../pom.xml | 4 +- .../CassandraChatMemoryAutoConfiguration.java | 8 +- ...assandraChatMemoryAutoConfigurationIT.java | 33 ++- .../pom.xml | 4 +- .../pom.xml | 4 +- .../memory/cassandra/CassandraChatMemory.java | 152 +---------- .../CassandraChatMemoryRepository.java | 244 ++++++++++++++++++ ...a => CassandraChatMemoryRepositoryIT.java} | 78 ++++-- .../modules/ROOT/pages/api/chat-memory.adoc | 78 ++++++ .../modules/ROOT/pages/api/chatclient.adoc | 2 + 10 files changed, 416 insertions(+), 191 deletions(-) create mode 100644 memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java rename memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/{CassandraChatMemoryIT.java => CassandraChatMemoryRepositoryIT.java} (76%) diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml index aeeb345a742..d6ad7ff7090 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/pom.xml @@ -11,8 +11,8 @@ spring-ai-autoconfigure-model-chat-memory-cassandra jar - Spring AI Cassandra Chat Memory Auto Configuration - Spring AI Cassandra Chat Memory Auto Configuration + Spring AI Apache Cassandra Chat Memory Auto Configuration + Spring AI Apache Cassandra Chat Memory Auto Configuration https://github.com/spring-projects/spring-ai diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java index 58f63d08d0d..ed36a61e2f4 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/main/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfiguration.java @@ -18,8 +18,8 @@ import com.datastax.oss.driver.api.core.CqlSession; -import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory; import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository; import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration; @@ -36,13 +36,13 @@ * @since 1.0.0 */ @AutoConfiguration(after = CassandraAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class) -@ConditionalOnClass({ CassandraChatMemory.class, CqlSession.class }) +@ConditionalOnClass({ CassandraChatMemoryRepository.class, CqlSession.class }) @EnableConfigurationProperties(CassandraChatMemoryProperties.class) public class CassandraChatMemoryAutoConfiguration { @Bean @ConditionalOnMissingBean - public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) { + public CassandraChatMemoryRepository chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) { var builder = CassandraChatMemoryConfig.builder().withCqlSession(cqlSession); @@ -58,7 +58,7 @@ public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties, builder = builder.withTimeToLive(properties.getTimeToLive()); } - return CassandraChatMemory.create(builder.build()); + return CassandraChatMemoryRepository.create(builder.build()); } } diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java index 8bd97965349..5139cb12d67 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java @@ -26,7 +26,7 @@ import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; -import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory; +import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; @@ -61,30 +61,29 @@ void addAndGet() { .withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter()) .withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive()) .run(context -> { - CassandraChatMemory memory = context.getBean(CassandraChatMemory.class); + CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class); String sessionId = UUIDs.timeBased().toString(); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + assertThat(memory.findByConversationId(sessionId)).isEmpty(); - memory.add(sessionId, new UserMessage("test question")); + memory.saveAll(sessionId, List.of(new UserMessage("test question"))); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(1); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType()) - .isEqualTo(MessageType.USER); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question"); + assertThat(memory.findByConversationId(sessionId)).hasSize(1); + assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER); + assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question"); - memory.clear(sessionId); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty(); + memory.deleteByConversationId(sessionId); + assertThat(memory.findByConversationId(sessionId)).isEmpty(); - memory.add(sessionId, List.of(new UserMessage("test question"), new AssistantMessage("test answer"))); + memory.saveAll(sessionId, + List.of(new UserMessage("test question"), new AssistantMessage("test answer"))); - assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(2); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getMessageType()) + assertThat(memory.findByConversationId(sessionId)).hasSize(2); + assertThat(memory.findByConversationId(sessionId).get(1).getMessageType()) .isEqualTo(MessageType.ASSISTANT); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getText()).isEqualTo("test answer"); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType()) - .isEqualTo(MessageType.USER); - assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question"); + assertThat(memory.findByConversationId(sessionId).get(1).getText()).isEqualTo("test answer"); + assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER); + assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question"); CassandraChatMemoryProperties properties = context.getBean(CassandraChatMemoryProperties.class); assertThat(properties.getTimeToLive()).isEqualTo(getTimeToLive()); diff --git a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-cassandra/pom.xml b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-cassandra/pom.xml index 0c5e02a0324..e86c65d1f85 100644 --- a/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-cassandra/pom.xml +++ b/auto-configurations/vector-stores/spring-ai-autoconfigure-vector-store-cassandra/pom.xml @@ -27,8 +27,8 @@ spring-ai-autoconfigure-vector-store-cassandra jar - Spring AI Auto Configuration for Cassandra vector store - Spring AI Auto Configuration for Cassandra vector store + Spring AI Auto Configuration for Apache Cassandra vector store + Spring AI Auto Configuration for Apache Cassandra vector store https://github.com/spring-projects/spring-ai diff --git a/memory/spring-ai-model-chat-memory-cassandra/pom.xml b/memory/spring-ai-model-chat-memory-cassandra/pom.xml index 717edc5be5c..fc6c500b6c8 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/pom.xml +++ b/memory/spring-ai-model-chat-memory-cassandra/pom.xml @@ -27,8 +27,8 @@ spring-ai-model-chat-memory-cassandra - Spring AI Cassandra Chat Memory - Spring AI Cassandra Chat Memory implementation + Spring AI Apache Cassandra Chat Memory + Spring AI Apache Cassandra Chat Memory implementation https://github.com/spring-projects/spring-ai diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java index 2b405f65327..05aaf819d51 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java @@ -16,60 +16,32 @@ package org.springframework.ai.chat.memory.cassandra; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - -import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; -import com.datastax.oss.driver.api.core.cql.PreparedStatement; -import com.datastax.oss.driver.api.core.cql.Row; -import com.datastax.oss.driver.api.querybuilder.QueryBuilder; -import com.datastax.oss.driver.api.querybuilder.delete.Delete; -import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; -import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; -import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; -import com.datastax.oss.driver.api.querybuilder.select.Select; -import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig.SchemaColumn; -import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; /** + * @deprecated Use CassandraChatMemoryRepository + * * Create a CassandraChatMemory like CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); * * For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory - * * @author Mick Semb Wever * @since 1.0.0 */ +@Deprecated public final class CassandraChatMemory implements ChatMemory { - public static final String CONVERSATION_TS = CassandraChatMemory.class.getSimpleName() + "_message_timestamp"; - final CassandraChatMemoryConfig conf; - private final PreparedStatement addUserStmt; - - private final PreparedStatement addAssistantStmt; - - private final PreparedStatement getStmt; - - private final PreparedStatement deleteStmt; + final CassandraChatMemoryRepository repo; public CassandraChatMemory(CassandraChatMemoryConfig config) { this.conf = config; - this.conf.ensureSchemaExists(); - this.addUserStmt = prepareAddStmt(this.conf.userColumn); - this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn); - this.getStmt = prepareGetStatement(); - this.deleteStmt = prepareDeleteStmt(); + repo = CassandraChatMemoryRepository.create(conf); } public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { @@ -78,128 +50,22 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) { @Override public void add(String conversationId, List messages) { - final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); - messages.forEach(msg -> { - if (msg.getMetadata().containsKey(CONVERSATION_TS)) { - msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement())); - } - add(conversationId, msg); - }); + repo.saveAll(conversationId, messages); } @Override public void add(String sessionId, Message msg) { - - Preconditions.checkArgument( - !msg.getMetadata().containsKey(CONVERSATION_TS) - || msg.getMetadata().get(CONVERSATION_TS) instanceof Instant, - "messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS); - - msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now()); - - PreparedStatement stmt = getStatement(msg); - - List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = stmt.boundStatementBuilder(); - - for (int k = 0; k < primaryKeys.size(); ++k) { - SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); - } - - Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS); - - builder = builder.setInstant(CassandraChatMemoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant) - .setString("message", msg.getText()); - - this.conf.session.execute(builder.build()); - } - - PreparedStatement getStatement(Message msg) { - return switch (msg.getMessageType()) { - case USER -> this.addUserStmt; - case ASSISTANT -> this.addAssistantStmt; - default -> throw new IllegalArgumentException("Cant add type " + msg); - }; + repo.save(sessionId, msg); } @Override public void clear(String sessionId) { - - List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(); - - for (int k = 0; k < primaryKeys.size(); ++k) { - SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); - } - - this.conf.session.execute(builder.build()); + repo.deleteByConversationId(sessionId); } @Override public List get(String sessionId, int lastN) { - - List primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId); - BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", lastN); - - for (int k = 0; k < primaryKeys.size(); ++k) { - SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); - builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); - } - - List messages = new ArrayList<>(); - for (Row r : this.conf.session.execute(builder.build())) { - String assistant = r.getString(this.conf.assistantColumn); - String user = r.getString(this.conf.userColumn); - if (null != assistant) { - messages.add(new AssistantMessage(assistant)); - } - if (null != user) { - messages.add(new UserMessage(user)); - } - } - Collections.reverse(messages); - return messages; - } - - private PreparedStatement prepareAddStmt(String column) { - RegularInsert stmt = null; - InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); - for (var c : this.conf.schema.partitionKeys()) { - stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name())); - } - for (var c : this.conf.schema.clusteringKeys()) { - stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); - } - stmt = stmt.value(column, QueryBuilder.bindMarker("message")); - return this.conf.session.prepare(stmt.build()); - } - - private PreparedStatement prepareGetStatement() { - Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()).all(); - for (var c : this.conf.schema.partitionKeys()) { - stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); - } - for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) { - String columnName = this.conf.schema.clusteringKeys().get(i).name(); - stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); - } - stmt = stmt.limit(QueryBuilder.bindMarker("lastN")); - return this.conf.session.prepare(stmt.build()); - } - - private PreparedStatement prepareDeleteStmt() { - Delete stmt = null; - DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); - for (var c : this.conf.schema.partitionKeys()) { - stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); - } - for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) { - String columnName = this.conf.schema.clusteringKeys().get(i).name(); - stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); - } - return this.conf.session.prepare(stmt.build()); + return repo.findByConversationId(sessionId).subList(0, lastN); } } diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java new file mode 100644 index 00000000000..2e6f5c2ccac --- /dev/null +++ b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepository.java @@ -0,0 +1,244 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.cassandra; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + +import com.datastax.oss.driver.api.core.cql.BoundStatement; +import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.Row; +import com.datastax.oss.driver.api.querybuilder.QueryBuilder; +import com.datastax.oss.driver.api.querybuilder.delete.Delete; +import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection; +import com.datastax.oss.driver.api.querybuilder.insert.InsertInto; +import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert; +import com.datastax.oss.driver.api.querybuilder.select.Select; +import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.util.Assert; + +import static org.springframework.ai.chat.messages.MessageType.ASSISTANT; +import static org.springframework.ai.chat.messages.MessageType.USER; + +/** + * An implementation of {@link ChatMemoryRepository} for Apache Cassandra. + * + * @author Mick Semb Wever + * @since 1.0.0 + */ +public class CassandraChatMemoryRepository implements ChatMemoryRepository { + + public static final String CONVERSATION_TS = CassandraChatMemoryRepository.class.getSimpleName() + + "_message_timestamp"; + + final CassandraChatMemoryConfig conf; + + private final PreparedStatement allStmt; + + private final PreparedStatement addUserStmt; + + private final PreparedStatement addAssistantStmt; + + private final PreparedStatement getStmt; + + private final PreparedStatement deleteStmt; + + private CassandraChatMemoryRepository(CassandraChatMemoryConfig conf) { + Assert.notNull(conf, "conf cannot be null"); + this.conf = conf; + this.conf.ensureSchemaExists(); + this.allStmt = prepareAllStatement(); + this.addUserStmt = prepareAddStmt(this.conf.userColumn); + this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn); + this.getStmt = prepareGetStatement(); + this.deleteStmt = prepareDeleteStmt(); + } + + public static CassandraChatMemoryRepository create(CassandraChatMemoryConfig conf) { + return new CassandraChatMemoryRepository(conf); + } + + @Override + public List findConversationIds() { + List conversationIds = new ArrayList<>(); + long token = Long.MIN_VALUE; + boolean emptyQuery = false; + + while (!emptyQuery && token < Long.MAX_VALUE) { + BoundStatement stmt = this.allStmt.boundStatementBuilder().setLong("after_token", token).build(); + emptyQuery = true; + for (Row r : this.conf.session.execute(stmt)) { + emptyQuery = false; + conversationIds.add(r.getString(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME)); + token = r.getLong("t"); + } + } + return List.copyOf(conversationIds); + } + + @Override + public List findByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + + List primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); + BoundStatementBuilder builder = this.getStmt.boundStatementBuilder(); + + for (int k = 0; k < primaryKeys.size(); ++k) { + CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); + } + + List messages = new ArrayList<>(); + for (Row r : this.conf.session.execute(builder.build())) { + String assistant = r.getString(this.conf.assistantColumn); + String user = r.getString(this.conf.userColumn); + if (null != assistant) { + messages.add(new AssistantMessage(assistant)); + } + if (null != user) { + messages.add(new UserMessage(user)); + } + } + Collections.reverse(messages); + return messages; + } + + @Override + public void saveAll(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + + final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli()); + messages.forEach(msg -> { + if (msg.getMetadata().containsKey(CONVERSATION_TS)) { + msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement())); + } + save(conversationId, msg); + }); + } + + void save(String conversationId, Message msg) { + + Preconditions.checkArgument( + !msg.getMetadata().containsKey(CONVERSATION_TS) + || msg.getMetadata().get(CONVERSATION_TS) instanceof Instant, + "messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS); + + msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now()); + + PreparedStatement stmt = getStatement(msg); + + List primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); + BoundStatementBuilder builder = stmt.boundStatementBuilder(); + + for (int k = 0; k < primaryKeys.size(); ++k) { + CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); + } + + Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS); + + builder = builder.setInstant(CassandraChatMemoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant) + .setString("message", msg.getText()); + + this.conf.session.execute(builder.build()); + } + + @Override + public void deleteByConversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + + List primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId); + BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder(); + + for (int k = 0; k < primaryKeys.size(); ++k) { + CassandraChatMemoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k); + builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType()); + } + + this.conf.session.execute(builder.build()); + } + + private PreparedStatement prepareAddStmt(String column) { + RegularInsert stmt = null; + InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table()); + for (var c : this.conf.schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name())); + } + for (var c : this.conf.schema.clusteringKeys()) { + stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name())); + } + stmt = stmt.value(column, QueryBuilder.bindMarker("message")); + return this.conf.session.prepare(stmt.build()); + } + + private PreparedStatement prepareAllStatement() { + Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()) + .distinct() + .raw(String.format("token(%s)", CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME)) + .as("t") + .column(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME) + .whereToken(CassandraChatMemoryConfig.DEFAULT_SESSION_ID_NAME) + .isGreaterThan(QueryBuilder.bindMarker("after_token")) + .limit(10000); + + return this.conf.session.prepare(stmt.build()); + } + + private PreparedStatement prepareGetStatement() { + Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()).all(); + for (var c : this.conf.schema.partitionKeys()) { + stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) { + String columnName = this.conf.schema.clusteringKeys().get(i).name(); + stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); + } + return this.conf.session.prepare(stmt.build()); + } + + private PreparedStatement prepareDeleteStmt() { + Delete stmt = null; + DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table()); + for (var c : this.conf.schema.partitionKeys()) { + stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name())); + } + for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) { + String columnName = this.conf.schema.clusteringKeys().get(i).name(); + stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName)); + } + return this.conf.session.prepare(stmt.build()); + } + + private PreparedStatement getStatement(Message msg) { + return switch (msg.getMessageType()) { + case USER -> this.addUserStmt; + case ASSISTANT -> this.addAssistantStmt; + default -> throw new IllegalArgumentException("Cant add type " + msg); + }; + } + +} diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java similarity index 76% rename from memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java rename to memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java index 928d8b05ed2..1bf694f5057 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryRepositoryIT.java @@ -32,7 +32,6 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -44,27 +43,39 @@ import org.springframework.context.annotation.Bean; import static org.assertj.core.api.Assertions.assertThat; +import org.springframework.ai.chat.memory.ChatMemoryRepository; /** - * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT` + * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryRepositoryIT` * * @author Mick Semb Wever * @author Thomas Vitale * @since 1.0.0 */ @Testcontainers -class CassandraChatMemoryIT { +class CassandraChatMemoryRepositoryIT { @Container static CassandraContainer cassandraContainer = new CassandraContainer(CassandraImage.DEFAULT_IMAGE); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(CassandraChatMemoryIT.TestApplication.class); + .withUserConfiguration(CassandraChatMemoryRepositoryIT.TestApplication.class); @Test - void ensureBeanGetsCreated() { + void ensureLegacyBeanGetsCreated() { + + new ApplicationContextRunner().withUserConfiguration(CassandraChatMemoryRepositoryIT.TestApplication.class) + .run(context -> { + CassandraChatMemory memory = context.getBean(CassandraChatMemory.class); + Assertions.assertNotNull(memory); + memory.conf.checkSchemaValid(); + }); + } + + @Test + void ensureBeansGetsCreated() { this.contextRunner.run(context -> { - CassandraChatMemory memory = context.getBean(CassandraChatMemory.class); + CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class); Assertions.assertNotNull(memory); memory.conf.checkSchemaValid(); }); @@ -74,7 +85,8 @@ void ensureBeanGetsCreated() { @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) void add_shouldInsertSingleMessage(String content, MessageType messageType) { this.contextRunner.run(context -> { - var chatMemory = context.getBean(ChatMemory.class); + var chatMemory = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); var message = switch (messageType) { case ASSISTANT -> new AssistantMessage(content); @@ -82,7 +94,7 @@ void add_shouldInsertSingleMessage(String content, MessageType messageType) { default -> throw new IllegalArgumentException("Type not supported: " + messageType); }; - chatMemory.add(sessionId, message); + chatMemory.saveAll(sessionId, List.of(message)); var cqlSession = context.getBean(CqlSession.class); var query = """ @@ -113,12 +125,13 @@ else if (messageType == MessageType.USER) { @Test void add_shouldInsertMessages() { this.contextRunner.run(context -> { - var chatMemory = context.getBean(ChatMemory.class); + var chatMemory = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant"), new UserMessage("Message from user")); - chatMemory.add(sessionId, messages); + chatMemory.saveAll(sessionId, messages); var cqlSession = context.getBean(CqlSession.class); var query = """ @@ -154,15 +167,18 @@ else if (message.getMessageType() == MessageType.USER) { @Test void get_shouldReturnMessages() { this.contextRunner.run(context -> { - var chatMemory = context.getBean(ChatMemory.class); + var chatMemory = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant 1 - " + sessionId), new AssistantMessage("Message from assistant 2 - " + sessionId), new UserMessage("Message from user - " + sessionId)); - chatMemory.add(sessionId, messages); + chatMemory.saveAll(sessionId, messages); - var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + assertThat(chatMemory.findConversationIds()).isNotEmpty(); + + var results = chatMemory.findByConversationId(sessionId); assertThat(results.size()).isEqualTo(messages.size()); @@ -179,15 +195,17 @@ void get_shouldReturnMessages() { @Test void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { this.contextRunner.run(context -> { - var chatMemory = context.getBean(ChatMemory.class); + var chatMemory = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); var userMessage = new UserMessage("Message from user - " + sessionId); var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId); - chatMemory.add(sessionId, userMessage); - chatMemory.add(sessionId, assistantMessage); + chatMemory.saveAll(sessionId, List.of(userMessage, assistantMessage)); + + assertThat(chatMemory.findConversationIds()).isNotEmpty(); - var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + var results = chatMemory.findByConversationId(sessionId); assertThat(results.size()).isEqualTo(2); @@ -205,14 +223,17 @@ void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { @Test void clear_shouldDeleteMessages() { this.contextRunner.run(context -> { - var chatMemory = context.getBean(ChatMemory.class); + var chatMemory = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemory instanceof CassandraChatMemoryRepository); var sessionId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant - " + sessionId), new UserMessage("Message from user - " + sessionId)); - chatMemory.add(sessionId, messages); + chatMemory.saveAll(sessionId, messages); - chatMemory.clear(sessionId); + assertThat(chatMemory.findConversationIds()).isNotEmpty(); + + chatMemory.deleteByConversationId(sessionId); var cqlSession = context.getBean(CqlSession.class); var query = """ @@ -232,7 +253,7 @@ SELECT COUNT(*) public static class TestApplication { @Bean - public CassandraChatMemory memory(CqlSession cqlSession) { + public CassandraChatMemory memoryLegacy(CqlSession cqlSession) { var conf = CassandraChatMemoryConfig.builder() .withCqlSession(cqlSession) @@ -246,6 +267,21 @@ public CassandraChatMemory memory(CqlSession cqlSession) { return CassandraChatMemory.create(conf); } + @Bean + public CassandraChatMemoryRepository memory(CqlSession cqlSession) { + + var conf = CassandraChatMemoryConfig.builder() + .withCqlSession(cqlSession) + .withKeyspaceName("test_" + CassandraChatMemoryConfig.DEFAULT_KEYSPACE_NAME) + .withAssistantColumnName("a") + .withUserColumnName("u") + .withTimeToLive(Duration.ofMinutes(1)) + .build(); + + conf.dropKeyspace(); + return CassandraChatMemoryRepository.create(conf); + } + @Bean public CqlSession cqlSession() { return new CqlSessionBuilder() diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc index 8b1427f3f97..39a2fd80cca 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat-memory.adoc @@ -136,6 +136,84 @@ You can disable the schema initialization by setting the property `spring.ai.cha If your project uses a tool like Flyway or Liquibase to manage your database schemas, you can disable the schema initialization and refer to link:https://github.com/spring-projects/spring-ai/tree/main/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc[these SQL scripts] for configuring those tools to create the `ai_chat_memory` table. + +=== CassandraChatMemoryRepository + +`CassandraChatMemoryRepository` uses Aache Cassandra to store messages. It is suitable for applications that require persistent storage of chat memory, especially for availability, or at scale, or when taking advantage of time-to-live (TTL) messages. + +First, add the following dependency to your project: + +[tabs] +====== +Maven:: ++ +[source, xml] +---- + + org.springframework.ai + spring-ai-starter-model-chat-memory-cassandra + +---- + +Gradle:: ++ +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-cassandra' +} +---- +====== + +Spring AI provides auto-configuration for the `CassandraChatMemoryRepository`, that you can use directly in your application. + +[source,java] +---- +@Autowired +CassandraChatMemoryRepository chatMemoryRepository; + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +If you'd rather create the `CassandraChatMemoryRepository` manually, you can do so by providing a `CassandraChatMemoryConfig` instance: + +[source,java] +---- +ChatMemoryRepository chatMemoryRepository = CassandraChatMemoryRepository + .create(CassandraChatMemoryConfig.builder().withCqlSession(cqlSession)); + +ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(10) + .build(); +---- + +==== Configuration Properties + +[cols="2,5,1",stripes=even] +|=== +|Property | Description | Default Value +| `spring.cassandra.contactPoints` | host(s) to initiate cluster discovery | `127.0.0.1` +| `spring.cassandra.port` | Cassandra native protocol port to connect to | `9042` +| `spring.cassandra.localDatacenter` | Cassandra datacenter to connect to | `datacenter1` +| `spring.ai.chat.memory.cassandra.time-to-live` | Time to live (TTL) messages are written with in Cassandra | +| `spring.ai.chat.memory.cassandra.keyspace` | Cassandra keyspace | `springframework` +| `spring.ai.chat.memory.cassandra.table` | Cassandra table | `ai_chat_memory` +| `spring.ai.chat.memory.cassandra.initialize-schema` | Whether to initialize the schema on startup. | `true` +|=== + +Configuration properties are the same and compatible with the previous and now deprecated `CassandraChatMemory`. + +==== Schema Initialization + +The auto-configuration will automatically create the `ai_chat_memory` table. + +You can disable the schema initialization by setting the property `spring.ai.chat.memory.repository.jdbc.initialize-schema` to `false`. + + === Neo4j ChatMemoryRepository `Neo4jChatMemoryRepository` is a built-in implementation that uses Neo4j to store chat messages as nodes and relationships in a property graph database. It is suitable for applications that want to leverage Neo4j's graph capabilities for chat memory persistence. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index ab5629bba48..32b0b70cb9d 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -545,6 +545,8 @@ To create a `CassandraChatMemory` with `time-to-live`: CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); ---- +IMPORTANT: Refer to the new xref:api/chat-memory.adoc#_cassandra_repository[Cassandra Chat Memory Repository] documentation for the current features and capabilities. + === Neo4jChatMemory The Neo4j chat memory supports the following configuration parameters: