2626import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
2727import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
2828import com .datastax .oss .driver .api .core .cql .Row ;
29+ import com .datastax .oss .driver .api .core .data .UdtValue ;
30+ import com .datastax .oss .driver .api .core .type .UserDefinedType ;
2931import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
32+ import com .datastax .oss .driver .api .querybuilder .SchemaBuilder ;
3033import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
3134import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
3235import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
3336import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
3437import com .datastax .oss .driver .api .querybuilder .select .Select ;
3538import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
39+ import java .util .Map ;
3640import org .springframework .ai .chat .memory .ChatMemoryRepository ;
3741import org .springframework .ai .chat .messages .AssistantMessage ;
3842import org .springframework .ai .chat .messages .Message ;
43+ import org .springframework .ai .chat .messages .MessageType ;
44+ import static org .springframework .ai .chat .messages .MessageType .ASSISTANT ;
45+ import org .springframework .ai .chat .messages .SystemMessage ;
46+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
3947import org .springframework .ai .chat .messages .UserMessage ;
4048import org .springframework .util .Assert ;
4149
@@ -54,23 +62,17 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
5462
5563 private final PreparedStatement allStmt ;
5664
57- private final PreparedStatement addUserStmt ;
58-
59- private final PreparedStatement addAssistantStmt ;
65+ private final PreparedStatement addStmt ;
6066
6167 private final PreparedStatement getStmt ;
6268
63- private final PreparedStatement deleteStmt ;
64-
6569 private CassandraChatMemoryRepository (CassandraChatMemoryRepositoryConfig conf ) {
6670 Assert .notNull (conf , "conf cannot be null" );
6771 this .conf = conf ;
6872 this .conf .ensureSchemaExists ();
6973 this .allStmt = prepareAllStatement ();
70- this .addUserStmt = prepareAddStmt (this .conf .userColumn );
71- this .addAssistantStmt = prepareAddStmt (this .conf .assistantColumn );
74+ this .addStmt = prepareAddStmt ();
7275 this .getStmt = prepareGetStatement ();
73- this .deleteStmt = prepareDeleteStmt ();
7476 }
7577
7678 public static CassandraChatMemoryRepository create (CassandraChatMemoryRepositoryConfig conf ) {
@@ -97,6 +99,10 @@ public List<String> findConversationIds() {
9799
98100 @ Override
99101 public List <Message > findByConversationId (String conversationId ) {
102+ return findByConversationIdWithLimit (conversationId , 1 );
103+ }
104+
105+ List <Message > findByConversationIdWithLimit (String conversationId , int limit ) {
100106 Assert .hasText (conversationId , "conversationId cannot be null or empty" );
101107
102108 List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
@@ -106,19 +112,14 @@ public List<Message> findByConversationId(String conversationId) {
106112 CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
107113 builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
108114 }
115+ builder = builder .setInt ("legacy_limit" , limit );
109116
110117 List <Message > messages = new ArrayList <>();
111118 for (Row r : this .conf .session .execute (builder .build ())) {
112- String assistant = r .getString (this .conf .assistantColumn );
113- String user = r .getString (this .conf .userColumn );
114- if (null != assistant ) {
115- messages .add (new AssistantMessage (assistant ));
116- }
117- if (null != user ) {
118- messages .add (new UserMessage (user ));
119+ for (UdtValue udt : r .getList (this .conf .messagesColumn , UdtValue .class )) {
120+ messages .add (getMessage (udt ));
119121 }
120122 }
121- Collections .reverse (messages );
122123 return messages ;
123124 }
124125
@@ -128,58 +129,49 @@ public void saveAll(String conversationId, List<Message> messages) {
128129 Assert .notNull (messages , "messages cannot be null" );
129130 Assert .noNullElements (messages , "messages cannot contain null elements" );
130131
131- final AtomicLong instantSeq = new AtomicLong (Instant .now ().toEpochMilli ());
132- messages .forEach (msg -> {
133- if (msg .getMetadata ().containsKey (CONVERSATION_TS )) {
134- msg .getMetadata ().put (CONVERSATION_TS , Instant .ofEpochMilli (instantSeq .getAndIncrement ()));
135- }
136- save (conversationId , msg );
137- });
138- }
139-
140- void save (String conversationId , Message msg ) {
141-
142- Preconditions .checkArgument (
143- !msg .getMetadata ().containsKey (CONVERSATION_TS )
144- || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
145- "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
146-
147- msg .getMetadata ().putIfAbsent (CONVERSATION_TS , Instant .now ());
148-
149- PreparedStatement stmt = getStatement (msg );
150-
132+ Instant instant = Instant .now ();
151133 List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
152- BoundStatementBuilder builder = stmt .boundStatementBuilder ();
134+ BoundStatementBuilder builder = addStmt .boundStatementBuilder ();
153135
154136 for (int k = 0 ; k < primaryKeys .size (); ++k ) {
155137 CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
156138 builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
157139 }
158140
159- Instant instant = (Instant ) msg .getMetadata ().get (CONVERSATION_TS );
141+ List <UdtValue > msgs = new ArrayList <>();
142+ for (Message msg : messages ) {
143+
144+ Preconditions .checkArgument (
145+ !msg .getMetadata ().containsKey (CONVERSATION_TS )
146+ || msg .getMetadata ().get (CONVERSATION_TS ) instanceof Instant ,
147+ "messages only accept metadata '%s' entries of type Instant" , CONVERSATION_TS );
160148
149+ msg .getMetadata ().putIfAbsent (CONVERSATION_TS , instant );
150+
151+ UdtValue udt = this .conf .session .getMetadata ()
152+ .getKeyspace (this .conf .schema .keyspace ())
153+ .get ()
154+ .getUserDefinedType (this .conf .messageUDT )
155+ .get ()
156+ .newValue ()
157+ .setInstant (this .conf .messageUdtTimestampColumn , (Instant ) msg .getMetadata ().get (CONVERSATION_TS ))
158+ .setString (this .conf .messageUdtTypeColumn , msg .getMessageType ().name ())
159+ .setString (this .conf .messageUdtContentColumn , msg .getText ());
160+
161+ msgs .add (udt );
162+ }
161163 builder = builder .setInstant (CassandraChatMemoryRepositoryConfig .DEFAULT_EXCHANGE_ID_NAME , instant )
162- .setString ( "message " , msg . getText () );
164+ .setList ( "msgs " , msgs , UdtValue . class );
163165
164166 this .conf .session .execute (builder .build ());
165167 }
166168
167169 @ Override
168170 public void deleteByConversationId (String conversationId ) {
169- Assert .hasText (conversationId , "conversationId cannot be null or empty" );
170-
171- List <Object > primaryKeys = this .conf .primaryKeyTranslator .apply (conversationId );
172- BoundStatementBuilder builder = this .deleteStmt .boundStatementBuilder ();
173-
174- for (int k = 0 ; k < primaryKeys .size (); ++k ) {
175- CassandraChatMemoryRepositoryConfig .SchemaColumn keyColumn = this .conf .getPrimaryKeyColumn (k );
176- builder = builder .set (keyColumn .name (), primaryKeys .get (k ), keyColumn .javaType ());
177- }
178-
179- this .conf .session .execute (builder .build ());
171+ saveAll (conversationId , List .of ());
180172 }
181173
182- private PreparedStatement prepareAddStmt (String column ) {
174+ private PreparedStatement prepareAddStmt () {
183175 RegularInsert stmt = null ;
184176 InsertInto stmtStart = QueryBuilder .insertInto (this .conf .schema .keyspace (), this .conf .schema .table ());
185177 for (var c : this .conf .schema .partitionKeys ()) {
@@ -188,7 +180,7 @@ private PreparedStatement prepareAddStmt(String column) {
188180 for (var c : this .conf .schema .clusteringKeys ()) {
189181 stmt = stmt .value (c .name (), QueryBuilder .bindMarker (c .name ()));
190182 }
191- stmt = stmt .value (column , QueryBuilder .bindMarker ("message " ));
183+ stmt = stmt .value (this . conf . messagesColumn , QueryBuilder .bindMarker ("msgs " ));
192184 return this .conf .session .prepare (stmt .build ());
193185 }
194186
@@ -214,28 +206,27 @@ private PreparedStatement prepareGetStatement() {
214206 String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
215207 stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
216208 }
209+ stmt = stmt .limit (QueryBuilder .bindMarker ("legacy_limit" ));
217210 return this .conf .session .prepare (stmt .build ());
218211 }
219212
220- private PreparedStatement prepareDeleteStmt () {
221- Delete stmt = null ;
222- DeleteSelection stmtStart = QueryBuilder .deleteFrom (this .conf .schema .keyspace (), this .conf .schema .table ());
223- for (var c : this .conf .schema .partitionKeys ()) {
224- stmt = (null != stmt ? stmt : stmtStart ).whereColumn (c .name ()).isEqualTo (QueryBuilder .bindMarker (c .name ()));
225- }
226- for (int i = 0 ; i + 1 < this .conf .schema .clusteringKeys ().size (); ++i ) {
227- String columnName = this .conf .schema .clusteringKeys ().get (i ).name ();
228- stmt = stmt .whereColumn (columnName ).isEqualTo (QueryBuilder .bindMarker (columnName ));
213+ private Message getMessage (UdtValue udt ) {
214+ String content = udt .getString (this .conf .messageUdtContentColumn );
215+ Map <String , Object > props = Map .of (CONVERSATION_TS , udt .getInstant (this .conf .messageUdtTimestampColumn ));
216+ switch (MessageType .valueOf (udt .getString (this .conf .messageUdtTypeColumn ))) {
217+ case ASSISTANT :
218+ return new AssistantMessage (content , props );
219+ case USER :
220+ return UserMessage .builder ().text (content ).metadata (props ).build ();
221+ case SYSTEM :
222+ return SystemMessage .builder ().text (content ).metadata (props ).build ();
223+ case TOOL :
224+ // todo – persist ToolResponse somehow
225+ return new ToolResponseMessage (List .of (), props );
226+ default :
227+ throw new IllegalStateException (
228+ String .format ("unknown message type %s" , udt .getString (this .conf .messageUdtTypeColumn )));
229229 }
230- return this .conf .session .prepare (stmt .build ());
231- }
232-
233- private PreparedStatement getStatement (Message msg ) {
234- return switch (msg .getMessageType ()) {
235- case USER -> this .addUserStmt ;
236- case ASSISTANT -> this .addAssistantStmt ;
237- default -> throw new IllegalArgumentException ("Cant add type " + msg );
238- };
239230 }
240231
241232}
0 commit comments