Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
</parent>
<artifactId>spring-ai-autoconfigure-model-chat-memory-cassandra</artifactId>
<packaging>jar</packaging>
<name>Spring AI Cassandra Chat Memory Auto Configuration</name>
<description>Spring AI Cassandra Chat Memory Auto Configuration</description>
<name>Spring AI Apache Cassandra Chat Memory Auto Configuration</name>
<description>Spring AI Apache Cassandra Chat Memory Auto Configuration</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import com.datastax.oss.driver.api.core.CqlSession;

import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory;
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig;
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
Expand All @@ -36,13 +36,13 @@
* @since 1.0.0
*/
@AutoConfiguration(after = CassandraAutoConfiguration.class, before = ChatMemoryAutoConfiguration.class)
@ConditionalOnClass({ CassandraChatMemory.class, CqlSession.class })
@ConditionalOnClass({ CassandraChatMemoryRepository.class, CqlSession.class })
@EnableConfigurationProperties(CassandraChatMemoryProperties.class)
public class CassandraChatMemoryAutoConfiguration {

@Bean
@ConditionalOnMissingBean
public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) {
public CassandraChatMemoryRepository chatMemory(CassandraChatMemoryProperties properties, CqlSession cqlSession) {

var builder = CassandraChatMemoryConfig.builder().withCqlSession(cqlSession);

Expand All @@ -58,7 +58,7 @@ public CassandraChatMemory chatMemory(CassandraChatMemoryProperties properties,
builder = builder.withTimeToLive(properties.getTimeToLive());
}

return CassandraChatMemory.create(builder.build());
return CassandraChatMemoryRepository.create(builder.build());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import org.springframework.ai.chat.memory.cassandra.CassandraChatMemory;
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
Expand Down Expand Up @@ -61,30 +61,29 @@ void addAndGet() {
.withPropertyValues("spring.cassandra.localDatacenter=" + cassandraContainer.getLocalDatacenter())
.withPropertyValues("spring.ai.chat.memory.cassandra.time-to-live=" + getTimeToLive())
.run(context -> {
CassandraChatMemory memory = context.getBean(CassandraChatMemory.class);
CassandraChatMemoryRepository memory = context.getBean(CassandraChatMemoryRepository.class);

String sessionId = UUIDs.timeBased().toString();
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
assertThat(memory.findByConversationId(sessionId)).isEmpty();

memory.add(sessionId, new UserMessage("test question"));
memory.saveAll(sessionId, List.of(new UserMessage("test question")));

assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(1);
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
.isEqualTo(MessageType.USER);
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question");
assertThat(memory.findByConversationId(sessionId)).hasSize(1);
assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER);
assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question");

memory.clear(sessionId);
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
memory.deleteByConversationId(sessionId);
assertThat(memory.findByConversationId(sessionId)).isEmpty();

memory.add(sessionId, List.of(new UserMessage("test question"), new AssistantMessage("test answer")));
memory.saveAll(sessionId,
List.of(new UserMessage("test question"), new AssistantMessage("test answer")));

assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(2);
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getMessageType())
assertThat(memory.findByConversationId(sessionId)).hasSize(2);
assertThat(memory.findByConversationId(sessionId).get(1).getMessageType())
.isEqualTo(MessageType.ASSISTANT);
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getText()).isEqualTo("test answer");
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
.isEqualTo(MessageType.USER);
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question");
assertThat(memory.findByConversationId(sessionId).get(1).getText()).isEqualTo("test answer");
assertThat(memory.findByConversationId(sessionId).get(0).getMessageType()).isEqualTo(MessageType.USER);
assertThat(memory.findByConversationId(sessionId).get(0).getText()).isEqualTo("test question");

CassandraChatMemoryProperties properties = context.getBean(CassandraChatMemoryProperties.class);
assertThat(properties.getTimeToLive()).isEqualTo(getTimeToLive());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
</parent>
<artifactId>spring-ai-autoconfigure-vector-store-cassandra</artifactId>
<packaging>jar</packaging>
<name>Spring AI Auto Configuration for Cassandra vector store</name>
<description>Spring AI Auto Configuration for Cassandra vector store</description>
<name>Spring AI Auto Configuration for Apache Cassandra vector store</name>
<description>Spring AI Auto Configuration for Apache Cassandra vector store</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
Expand Down
4 changes: 2 additions & 2 deletions memory/spring-ai-model-chat-memory-cassandra/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
</parent>

<artifactId>spring-ai-model-chat-memory-cassandra</artifactId>
<name>Spring AI Cassandra Chat Memory</name>
<description>Spring AI Cassandra Chat Memory implementation</description>
<name>Spring AI Apache Cassandra Chat Memory</name>
<description>Spring AI Apache Cassandra Chat Memory implementation</description>

<url>https://github.com/spring-projects/spring-ai</url>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,60 +16,32 @@

package org.springframework.ai.chat.memory.cassandra;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.cql.Row;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.delete.DeleteSelection;
import com.datastax.oss.driver.api.querybuilder.insert.InsertInto;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig.SchemaColumn;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;

/**
* @deprecated Use CassandraChatMemoryRepository
*
* Create a CassandraChatMemory like <code>
CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
</code>
*
* For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory
*
* @author Mick Semb Wever
* @since 1.0.0
*/
@Deprecated
public final class CassandraChatMemory implements ChatMemory {

public static final String CONVERSATION_TS = CassandraChatMemory.class.getSimpleName() + "_message_timestamp";

final CassandraChatMemoryConfig conf;

private final PreparedStatement addUserStmt;

private final PreparedStatement addAssistantStmt;

private final PreparedStatement getStmt;

private final PreparedStatement deleteStmt;
final CassandraChatMemoryRepository repo;

public CassandraChatMemory(CassandraChatMemoryConfig config) {
this.conf = config;
this.conf.ensureSchemaExists();
this.addUserStmt = prepareAddStmt(this.conf.userColumn);
this.addAssistantStmt = prepareAddStmt(this.conf.assistantColumn);
this.getStmt = prepareGetStatement();
this.deleteStmt = prepareDeleteStmt();
repo = CassandraChatMemoryRepository.create(conf);
}

public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
Expand All @@ -78,128 +50,22 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {

@Override
public void add(String conversationId, List<Message> messages) {
final AtomicLong instantSeq = new AtomicLong(Instant.now().toEpochMilli());
messages.forEach(msg -> {
if (msg.getMetadata().containsKey(CONVERSATION_TS)) {
msg.getMetadata().put(CONVERSATION_TS, Instant.ofEpochMilli(instantSeq.getAndIncrement()));
}
add(conversationId, msg);
});
repo.saveAll(conversationId, messages);
}

@Override
public void add(String sessionId, Message msg) {

Preconditions.checkArgument(
!msg.getMetadata().containsKey(CONVERSATION_TS)
|| msg.getMetadata().get(CONVERSATION_TS) instanceof Instant,
"messages only accept metadata '%s' entries of type Instant", CONVERSATION_TS);

msg.getMetadata().putIfAbsent(CONVERSATION_TS, Instant.now());

PreparedStatement stmt = getStatement(msg);

List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
BoundStatementBuilder builder = stmt.boundStatementBuilder();

for (int k = 0; k < primaryKeys.size(); ++k) {
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}

Instant instant = (Instant) msg.getMetadata().get(CONVERSATION_TS);

builder = builder.setInstant(CassandraChatMemoryConfig.DEFAULT_EXCHANGE_ID_NAME, instant)
.setString("message", msg.getText());

this.conf.session.execute(builder.build());
}

PreparedStatement getStatement(Message msg) {
return switch (msg.getMessageType()) {
case USER -> this.addUserStmt;
case ASSISTANT -> this.addAssistantStmt;
default -> throw new IllegalArgumentException("Cant add type " + msg);
};
repo.save(sessionId, msg);
}

@Override
public void clear(String sessionId) {

List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
BoundStatementBuilder builder = this.deleteStmt.boundStatementBuilder();

for (int k = 0; k < primaryKeys.size(); ++k) {
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}

this.conf.session.execute(builder.build());
repo.deleteByConversationId(sessionId);
}

@Override
public List<Message> get(String sessionId, int lastN) {

List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(sessionId);
BoundStatementBuilder builder = this.getStmt.boundStatementBuilder().setInt("lastN", lastN);

for (int k = 0; k < primaryKeys.size(); ++k) {
SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
builder = builder.set(keyColumn.name(), primaryKeys.get(k), keyColumn.javaType());
}

List<Message> messages = new ArrayList<>();
for (Row r : this.conf.session.execute(builder.build())) {
String assistant = r.getString(this.conf.assistantColumn);
String user = r.getString(this.conf.userColumn);
if (null != assistant) {
messages.add(new AssistantMessage(assistant));
}
if (null != user) {
messages.add(new UserMessage(user));
}
}
Collections.reverse(messages);
return messages;
}

private PreparedStatement prepareAddStmt(String column) {
RegularInsert stmt = null;
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
for (var c : this.conf.schema.partitionKeys()) {
stmt = (null != stmt ? stmt : stmtStart).value(c.name(), QueryBuilder.bindMarker(c.name()));
}
for (var c : this.conf.schema.clusteringKeys()) {
stmt = stmt.value(c.name(), QueryBuilder.bindMarker(c.name()));
}
stmt = stmt.value(column, QueryBuilder.bindMarker("message"));
return this.conf.session.prepare(stmt.build());
}

private PreparedStatement prepareGetStatement() {
Select stmt = QueryBuilder.selectFrom(this.conf.schema.keyspace(), this.conf.schema.table()).all();
for (var c : this.conf.schema.partitionKeys()) {
stmt = stmt.whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
}
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
String columnName = this.conf.schema.clusteringKeys().get(i).name();
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
}
stmt = stmt.limit(QueryBuilder.bindMarker("lastN"));
return this.conf.session.prepare(stmt.build());
}

private PreparedStatement prepareDeleteStmt() {
Delete stmt = null;
DeleteSelection stmtStart = QueryBuilder.deleteFrom(this.conf.schema.keyspace(), this.conf.schema.table());
for (var c : this.conf.schema.partitionKeys()) {
stmt = (null != stmt ? stmt : stmtStart).whereColumn(c.name()).isEqualTo(QueryBuilder.bindMarker(c.name()));
}
for (int i = 0; i + 1 < this.conf.schema.clusteringKeys().size(); ++i) {
String columnName = this.conf.schema.clusteringKeys().get(i).name();
stmt = stmt.whereColumn(columnName).isEqualTo(QueryBuilder.bindMarker(columnName));
}
return this.conf.session.prepare(stmt.build());
return repo.findByConversationId(sessionId).subList(0, lastN);
}

}
Loading