1616
1717package org .springframework .ai .chat .memory .cassandra ;
1818
19- import java .time .Instant ;
20- import java .util .ArrayList ;
21- import java .util .Collections ;
2219import java .util .List ;
23- import java .util .concurrent .atomic .AtomicLong ;
24-
25- import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
26- import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
27- import com .datastax .oss .driver .api .core .cql .Row ;
28- import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
29- import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
30- import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
31- import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
32- import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
33- import com .datastax .oss .driver .api .querybuilder .select .Select ;
34- import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
3520
3621import org .springframework .ai .chat .memory .ChatMemory ;
37- import org .springframework .ai .chat .memory .cassandra .CassandraChatMemoryConfig .SchemaColumn ;
38- import org .springframework .ai .chat .messages .AssistantMessage ;
3922import org .springframework .ai .chat .messages .Message ;
40- import org .springframework .ai .chat .messages .UserMessage ;
4123
4224/**
25+ * @deprecated Use CassandraChatMemoryRepository
26+ *
4327 * Create a CassandraChatMemory like <code>
4428 CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
4529 </code>
4630 *
4731 * For example @see org.springframework.ai.chat.memory.cassandra.CassandraChatMemory
48- *
4932 * @author Mick Semb Wever
5033 * @since 1.0.0
5134 */
35+ @ Deprecated
5236public final class CassandraChatMemory implements ChatMemory {
5337
54- public static final String CONVERSATION_TS = CassandraChatMemory .class .getSimpleName () + "_message_timestamp" ;
55-
5638 final CassandraChatMemoryConfig conf ;
5739
58- private final PreparedStatement addUserStmt ;
59-
60- private final PreparedStatement addAssistantStmt ;
61-
62- private final PreparedStatement getStmt ;
63-
64- private final PreparedStatement deleteStmt ;
40+ final CassandraChatMemoryRepository repo ;
6541
6642 public CassandraChatMemory (CassandraChatMemoryConfig config ) {
6743 this .conf = config ;
68- this .conf .ensureSchemaExists ();
69- this .addUserStmt = prepareAddStmt (this .conf .userColumn );
70- this .addAssistantStmt = prepareAddStmt (this .conf .assistantColumn );
71- this .getStmt = prepareGetStatement ();
72- this .deleteStmt = prepareDeleteStmt ();
44+ repo = CassandraChatMemoryRepository .create (conf );
7345 }
7446
7547 public static CassandraChatMemory create (CassandraChatMemoryConfig conf ) {
@@ -78,128 +50,22 @@ public static CassandraChatMemory create(CassandraChatMemoryConfig conf) {
7850
7951 @ Override
8052 public void add (String conversationId , List <Message > messages ) {
81- final AtomicLong instantSeq = new AtomicLong (Instant .now ().toEpochMilli ());
82- messages .forEach (msg -> {
83- if (msg .getMetadata ().containsKey (CONVERSATION_TS )) {
84- msg .getMetadata ().put (CONVERSATION_TS , Instant .ofEpochMilli (instantSeq .getAndIncrement ()));
85- }
86- add (conversationId , msg );
87- });
53+ repo .saveAll (conversationId , messages );
8854 }
8955
9056 @ Override
9157 public void add (String sessionId , Message msg ) {
92-
93- Preconditions .checkArgument (
94- !msg .getMetadata ().containsKey (CONVERSATION_TS )
95- || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
96- "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
97-
98- msg .getMetadata ().putIfAbsent (CONVERSATION_TS , Instant .now ());
99-
100- PreparedStatement stmt = getStatement (msg );
101-
102- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
103- BoundStatementBuilder builder = stmt .boundStatementBuilder ();
104-
105- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
106- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
107- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
108- }
109-
110- Instant instant = (Instant ) msg .getMetadata ().get (CONVERSATION_TS );
111-
112- builder = builder .setInstant (CassandraChatMemoryConfig .DEFAULT_EXCHANGE_ID_NAME , instant )
113- .setString ("message" , msg .getText ());
114-
115- this .conf .session .execute (builder .build ());
116- }
117-
118- PreparedStatement getStatement (Message msg ) {
119- return switch (msg .getMessageType ()) {
120- case USER -> this .addUserStmt ;
121- case ASSISTANT -> this .addAssistantStmt ;
122- default -> throw new IllegalArgumentException ("Cant add type " + msg );
123- };
58+ repo .save (sessionId , msg );
12459 }
12560
12661 @ Override
12762 public void clear (String sessionId ) {
128-
129- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
130- BoundStatementBuilder builder = this .deleteStmt .boundStatementBuilder ();
131-
132- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
133- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
134- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
135- }
136-
137- this .conf .session .execute (builder .build ());
63+ repo .deleteByConversationId (sessionId );
13864 }
13965
14066 @ Override
14167 public List <Message > get (String sessionId , int lastN ) {
142-
143- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (sessionId );
144- BoundStatementBuilder builder = this .getStmt .boundStatementBuilder ().setInt ("lastN" , lastN );
145-
146- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
147- SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
148- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
149- }
150-
151- List <Message > messages = new ArrayList <>();
152- for (Row r : this .conf .session .execute (builder .build ())) {
153- String assistant = r .getString (this .conf .assistantColumn );
154- String user = r .getString (this .conf .userColumn );
155- if (null != assistant ) {
156- messages .add (new AssistantMessage (assistant ));
157- }
158- if (null != user ) {
159- messages .add (new UserMessage (user ));
160- }
161- }
162- Collections .reverse (messages );
163- return messages ;
164- }
165-
166- private PreparedStatement prepareAddStmt (String column ) {
167- RegularInsert stmt = null ;
168- InsertInto stmtStart = QueryBuilder .insertInto (this .conf .schema .keyspace (), this .conf .schema .table ());
169- for (var c : this .conf .schema .partitionKeys ()) {
170- stmt = (null != stmt ? stmt : stmtStart ).value (c .name (), QueryBuilder .bindMarker (c .name ()));
171- }
172- for (var c : this .conf .schema .clusteringKeys ()) {
173- stmt = stmt .value (c .name (), QueryBuilder .bindMarker (c .name ()));
174- }
175- stmt = stmt .value (column , QueryBuilder .bindMarker ("message" ));
176- return this .conf .session .prepare (stmt .build ());
177- }
178-
179- private PreparedStatement prepareGetStatement () {
180- Select stmt = QueryBuilder .selectFrom (this .conf .schema .keyspace (), this .conf .schema .table ()).all ();
181- for (var c : this .conf .schema .partitionKeys ()) {
182- stmt = stmt .whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
183- }
184- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
185- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
186- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
187- }
188- stmt = stmt .limit (QueryBuilder .bindMarker ("lastN" ));
189- return this .conf .session .prepare (stmt .build ());
190- }
191-
192- private PreparedStatement prepareDeleteStmt () {
193- Delete stmt = null ;
194- DeleteSelection stmtStart = QueryBuilder .deleteFrom (this .conf .schema .keyspace (), this .conf .schema .table ());
195- for (var c : this .conf .schema .partitionKeys ()) {
196- stmt = (null != stmt ? stmt : stmtStart ).whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
197- }
198- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
199- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
200- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
201- }
202- return this .conf .session .prepare (stmt .build ());
68+ return repo .findByConversationId (sessionId ).subList (0 , lastN );
20369 }
20470
20571}
0 commit comments