1616
1717package org .springframework .ai .chat .memory .cassandra ;
1818
19- import java .time .Duration ;
20-
2119import com .datastax .oss .driver .api .core .CqlSession ;
2220import com .datastax .oss .driver .api .core .CqlSessionBuilder ;
21+ import com .datastax .oss .driver .api .core .cql .ResultSet ;
2322import org .junit .jupiter .api .Assertions ;
2423import org .junit .jupiter .api .Test ;
25- import org .testcontainers .containers .CassandraContainer ;
26- import org .testcontainers .junit .jupiter .Container ;
27- import org .testcontainers .junit .jupiter .Testcontainers ;
28-
24+ import org .junit .jupiter .params .ParameterizedTest ;
25+ import org .junit .jupiter .params .provider .CsvSource ;
26+ import org .springframework .ai .chat .memory .ChatMemory ;
27+ import org .springframework .ai .chat .messages .AssistantMessage ;
28+ import org .springframework .ai .chat .messages .Message ;
29+ import org .springframework .ai .chat .messages .MessageType ;
30+ import org .springframework .ai .chat .messages .UserMessage ;
2931import org .springframework .boot .SpringBootConfiguration ;
3032import org .springframework .boot .autoconfigure .EnableAutoConfiguration ;
3133import org .springframework .boot .autoconfigure .jdbc .DataSourceAutoConfiguration ;
3234import org .springframework .boot .test .context .runner .ApplicationContextRunner ;
3335import org .springframework .context .annotation .Bean ;
36+ import org .testcontainers .containers .CassandraContainer ;
37+ import org .testcontainers .junit .jupiter .Container ;
38+ import org .testcontainers .junit .jupiter .Testcontainers ;
39+
40+ import java .time .Duration ;
41+ import java .util .List ;
42+ import java .util .UUID ;
43+
44+ import static org .assertj .core .api .Assertions .assertThat ;
3445
3546/**
3647 * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT`
@@ -57,6 +68,163 @@ void ensureBeanGetsCreated() {
5768 });
5869 }
5970
71+ @ ParameterizedTest
72+ @ CsvSource ({ "Message from assistant,ASSISTANT" , "Message from user,USER" })
73+ void add_shouldInsertSingleMessage (String content , MessageType messageType ) {
74+ this .contextRunner .run (context -> {
75+ var chatMemory = context .getBean (ChatMemory .class );
76+ var sessionId = UUID .randomUUID ().toString ();
77+ var message = switch (messageType ) {
78+ case ASSISTANT -> new AssistantMessage (content );
79+ case USER -> new UserMessage (content );
80+ default -> throw new IllegalArgumentException ("Type not supported: " + messageType );
81+ };
82+
83+ chatMemory .add (sessionId , message );
84+
85+ var cqlSession = context .getBean (CqlSession .class );
86+ var query = """
87+ SELECT session_id, message_timestamp, a, u
88+ FROM test_springframework.ai_chat_memory
89+ WHERE session_id = ?
90+ """ ;
91+ ResultSet resultSet = cqlSession .execute (query , sessionId );
92+ var rows = resultSet .all ();
93+
94+ assertThat (rows .size ()).isEqualTo (1 );
95+
96+ var firstRow = rows .get (0 );
97+
98+ assertThat (firstRow .getString ("session_id" )).isEqualTo (sessionId );
99+ assertThat (firstRow .getInstant ("message_timestamp" )).isNotNull ();
100+ if (messageType == MessageType .ASSISTANT ) {
101+ assertThat (firstRow .getString ("a" )).isEqualTo (content );
102+ assertThat (firstRow .getString ("u" )).isNull ();
103+ }
104+ else if (messageType == MessageType .USER ) {
105+ assertThat (firstRow .getString ("a" )).isNull ();
106+ assertThat (firstRow .getString ("u" )).isEqualTo (content );
107+ }
108+ });
109+ }
110+
111+ @ Test
112+ void add_shouldInsertMessages () {
113+ this .contextRunner .run (context -> {
114+ var chatMemory = context .getBean (ChatMemory .class );
115+ var sessionId = UUID .randomUUID ().toString ();
116+ var messages = List .<Message >of (new AssistantMessage ("Message from assistant" ),
117+ new UserMessage ("Message from user" ));
118+
119+ chatMemory .add (sessionId , messages );
120+
121+ var cqlSession = context .getBean (CqlSession .class );
122+ var query = """
123+ SELECT session_id, message_timestamp, a, u
124+ FROM test_springframework.ai_chat_memory
125+ WHERE session_id = ?
126+ ORDER BY message_timestamp ASC
127+ """ ;
128+ ResultSet resultSet = cqlSession .execute (query , sessionId );
129+ var rows = resultSet .all ();
130+
131+ assertThat (rows .size ()).isEqualTo (messages .size ());
132+
133+ for (var i = 0 ; i < messages .size (); i ++) {
134+ var message = messages .get (i );
135+ var result = rows .get (i );
136+
137+ assertThat (result .getString ("session_id" )).isNotNull ();
138+ assertThat (result .getString ("session_id" )).isEqualTo (sessionId );
139+ if (message .getMessageType () == MessageType .ASSISTANT ) {
140+ assertThat (result .getString ("a" )).isEqualTo (message .getText ());
141+ assertThat (result .getString ("u" )).isNull ();
142+ }
143+ else if (message .getMessageType () == MessageType .USER ) {
144+ assertThat (result .getString ("a" )).isNull ();
145+ assertThat (result .getString ("u" )).isEqualTo (message .getText ());
146+ }
147+ assertThat (result .getInstant ("message_timestamp" )).isNotNull ();
148+ }
149+ });
150+ }
151+
152+ @ Test
153+ void get_shouldReturnMessages () {
154+ this .contextRunner .run (context -> {
155+ var chatMemory = context .getBean (ChatMemory .class );
156+ var sessionId = UUID .randomUUID ().toString ();
157+ var messages = List .<Message >of (new AssistantMessage ("Message from assistant 1 - " + sessionId ),
158+ new AssistantMessage ("Message from assistant 2 - " + sessionId ),
159+ new UserMessage ("Message from user - " + sessionId ));
160+
161+ chatMemory .add (sessionId , messages );
162+
163+ var results = chatMemory .get (sessionId , Integer .MAX_VALUE );
164+
165+ assertThat (results .size ()).isEqualTo (messages .size ());
166+
167+ for (var i = 0 ; i < messages .size (); i ++) {
168+ var message = messages .get (i );
169+ var result = results .get (i );
170+
171+ assertThat (result .getMessageType ()).isEqualTo (message .getMessageType ());
172+ assertThat (result .getText ()).isEqualTo (message .getText ());
173+ }
174+ });
175+ }
176+
177+ @ Test
178+ void get_afterMultipleAdds_shouldReturnMessagesInSameOrder () {
179+ this .contextRunner .run (context -> {
180+ var chatMemory = context .getBean (ChatMemory .class );
181+ var sessionId = UUID .randomUUID ().toString ();
182+ var userMessage = new UserMessage ("Message from user - " + sessionId );
183+ var assistantMessage = new AssistantMessage ("Message from assistant - " + sessionId );
184+
185+ chatMemory .add (sessionId , userMessage );
186+ chatMemory .add (sessionId , assistantMessage );
187+
188+ var results = chatMemory .get (sessionId , Integer .MAX_VALUE );
189+
190+ assertThat (results .size ()).isEqualTo (2 );
191+
192+ var messages = List .<Message >of (userMessage , assistantMessage );
193+ for (var i = 0 ; i < messages .size (); i ++) {
194+ var message = messages .get (i );
195+ var result = results .get (i );
196+
197+ assertThat (result .getMessageType ()).isEqualTo (message .getMessageType ());
198+ assertThat (result .getText ()).isEqualTo (message .getText ());
199+ }
200+ });
201+ }
202+
203+ @ Test
204+ void clear_shouldDeleteMessages () {
205+ this .contextRunner .run (context -> {
206+ var chatMemory = context .getBean (ChatMemory .class );
207+ var sessionId = UUID .randomUUID ().toString ();
208+ var messages = List .<Message >of (new AssistantMessage ("Message from assistant - " + sessionId ),
209+ new UserMessage ("Message from user - " + sessionId ));
210+
211+ chatMemory .add (sessionId , messages );
212+
213+ chatMemory .clear (sessionId );
214+
215+ var cqlSession = context .getBean (CqlSession .class );
216+ var query = """
217+ SELECT COUNT(*)
218+ FROM test_springframework.ai_chat_memory
219+ WHERE session_id = ?
220+ """ ;
221+ ResultSet resultSet = cqlSession .execute (query , sessionId );
222+ var count = resultSet .all ().get (0 ).getLong (0 );
223+
224+ assertThat (count ).isZero ();
225+ });
226+ }
227+
60228 @ SpringBootConfiguration
61229 @ EnableAutoConfiguration (exclude = { DataSourceAutoConfiguration .class })
62230 public static class TestApplication {
0 commit comments