Skip to content

Commit 9915739

Browse files
committed
chat-memory-cassandra : Added integration tests
Signed-off-by: Linar Abzaltdinov <[email protected]>
1 parent 8bf5d44 commit 9915739

File tree

1 file changed

+174
-6
lines changed

1 file changed

+174
-6
lines changed

memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java

Lines changed: 174 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,32 @@
1616

1717
package org.springframework.ai.chat.memory.cassandra;
1818

19-
import java.time.Duration;
20-
2119
import com.datastax.oss.driver.api.core.CqlSession;
2220
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
21+
import com.datastax.oss.driver.api.core.cql.ResultSet;
2322
import org.junit.jupiter.api.Assertions;
2423
import org.junit.jupiter.api.Test;
25-
import org.testcontainers.containers.CassandraContainer;
26-
import org.testcontainers.junit.jupiter.Container;
27-
import org.testcontainers.junit.jupiter.Testcontainers;
28-
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.CsvSource;
26+
import org.springframework.ai.chat.memory.ChatMemory;
27+
import org.springframework.ai.chat.messages.AssistantMessage;
28+
import org.springframework.ai.chat.messages.Message;
29+
import org.springframework.ai.chat.messages.MessageType;
30+
import org.springframework.ai.chat.messages.UserMessage;
2931
import org.springframework.boot.SpringBootConfiguration;
3032
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
3133
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
3234
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
3335
import org.springframework.context.annotation.Bean;
36+
import org.testcontainers.containers.CassandraContainer;
37+
import org.testcontainers.junit.jupiter.Container;
38+
import org.testcontainers.junit.jupiter.Testcontainers;
39+
40+
import java.time.Duration;
41+
import java.util.List;
42+
import java.util.UUID;
43+
44+
import static org.assertj.core.api.Assertions.assertThat;
3445

3546
/**
3647
* Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT`
@@ -57,6 +68,163 @@ void ensureBeanGetsCreated() {
5768
});
5869
}
5970

71+
@ParameterizedTest
72+
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" })
73+
void add_shouldInsertSingleMessage(String content, MessageType messageType) {
74+
this.contextRunner.run(context -> {
75+
var chatMemory = context.getBean(ChatMemory.class);
76+
var sessionId = UUID.randomUUID().toString();
77+
var message = switch (messageType) {
78+
case ASSISTANT -> new AssistantMessage(content);
79+
case USER -> new UserMessage(content);
80+
default -> throw new IllegalArgumentException("Type not supported: " + messageType);
81+
};
82+
83+
chatMemory.add(sessionId, message);
84+
85+
var cqlSession = context.getBean(CqlSession.class);
86+
var query = """
87+
SELECT session_id, message_timestamp, a, u
88+
FROM test_springframework.ai_chat_memory
89+
WHERE session_id = ?
90+
""";
91+
ResultSet resultSet = cqlSession.execute(query, sessionId);
92+
var rows = resultSet.all();
93+
94+
assertThat(rows.size()).isEqualTo(1);
95+
96+
var firstRow = rows.get(0);
97+
98+
assertThat(firstRow.getString("session_id")).isEqualTo(sessionId);
99+
assertThat(firstRow.getInstant("message_timestamp")).isNotNull();
100+
if (messageType == MessageType.ASSISTANT) {
101+
assertThat(firstRow.getString("a")).isEqualTo(content);
102+
assertThat(firstRow.getString("u")).isNull();
103+
}
104+
else if (messageType == MessageType.USER) {
105+
assertThat(firstRow.getString("a")).isNull();
106+
assertThat(firstRow.getString("u")).isEqualTo(content);
107+
}
108+
});
109+
}
110+
111+
@Test
112+
void add_shouldInsertMessages() {
113+
this.contextRunner.run(context -> {
114+
var chatMemory = context.getBean(ChatMemory.class);
115+
var sessionId = UUID.randomUUID().toString();
116+
var messages = List.<Message>of(new AssistantMessage("Message from assistant"),
117+
new UserMessage("Message from user"));
118+
119+
chatMemory.add(sessionId, messages);
120+
121+
var cqlSession = context.getBean(CqlSession.class);
122+
var query = """
123+
SELECT session_id, message_timestamp, a, u
124+
FROM test_springframework.ai_chat_memory
125+
WHERE session_id = ?
126+
ORDER BY message_timestamp ASC
127+
""";
128+
ResultSet resultSet = cqlSession.execute(query, sessionId);
129+
var rows = resultSet.all();
130+
131+
assertThat(rows.size()).isEqualTo(messages.size());
132+
133+
for (var i = 0; i < messages.size(); i++) {
134+
var message = messages.get(i);
135+
var result = rows.get(i);
136+
137+
assertThat(result.getString("session_id")).isNotNull();
138+
assertThat(result.getString("session_id")).isEqualTo(sessionId);
139+
if (message.getMessageType() == MessageType.ASSISTANT) {
140+
assertThat(result.getString("a")).isEqualTo(message.getText());
141+
assertThat(result.getString("u")).isNull();
142+
}
143+
else if (message.getMessageType() == MessageType.USER) {
144+
assertThat(result.getString("a")).isNull();
145+
assertThat(result.getString("u")).isEqualTo(message.getText());
146+
}
147+
assertThat(result.getInstant("message_timestamp")).isNotNull();
148+
}
149+
});
150+
}
151+
152+
@Test
153+
void get_shouldReturnMessages() {
154+
this.contextRunner.run(context -> {
155+
var chatMemory = context.getBean(ChatMemory.class);
156+
var sessionId = UUID.randomUUID().toString();
157+
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + sessionId),
158+
new AssistantMessage("Message from assistant 2 - " + sessionId),
159+
new UserMessage("Message from user - " + sessionId));
160+
161+
chatMemory.add(sessionId, messages);
162+
163+
var results = chatMemory.get(sessionId, Integer.MAX_VALUE);
164+
165+
assertThat(results.size()).isEqualTo(messages.size());
166+
167+
for (var i = 0; i < messages.size(); i++) {
168+
var message = messages.get(i);
169+
var result = results.get(i);
170+
171+
assertThat(result.getMessageType()).isEqualTo(message.getMessageType());
172+
assertThat(result.getText()).isEqualTo(message.getText());
173+
}
174+
});
175+
}
176+
177+
@Test
178+
void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() {
179+
this.contextRunner.run(context -> {
180+
var chatMemory = context.getBean(ChatMemory.class);
181+
var sessionId = UUID.randomUUID().toString();
182+
var userMessage = new UserMessage("Message from user - " + sessionId);
183+
var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId);
184+
185+
chatMemory.add(sessionId, userMessage);
186+
chatMemory.add(sessionId, assistantMessage);
187+
188+
var results = chatMemory.get(sessionId, Integer.MAX_VALUE);
189+
190+
assertThat(results.size()).isEqualTo(2);
191+
192+
var messages = List.<Message>of(userMessage, assistantMessage);
193+
for (var i = 0; i < messages.size(); i++) {
194+
var message = messages.get(i);
195+
var result = results.get(i);
196+
197+
assertThat(result.getMessageType()).isEqualTo(message.getMessageType());
198+
assertThat(result.getText()).isEqualTo(message.getText());
199+
}
200+
});
201+
}
202+
203+
@Test
204+
void clear_shouldDeleteMessages() {
205+
this.contextRunner.run(context -> {
206+
var chatMemory = context.getBean(ChatMemory.class);
207+
var sessionId = UUID.randomUUID().toString();
208+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + sessionId),
209+
new UserMessage("Message from user - " + sessionId));
210+
211+
chatMemory.add(sessionId, messages);
212+
213+
chatMemory.clear(sessionId);
214+
215+
var cqlSession = context.getBean(CqlSession.class);
216+
var query = """
217+
SELECT COUNT(*)
218+
FROM test_springframework.ai_chat_memory
219+
WHERE session_id = ?
220+
""";
221+
ResultSet resultSet = cqlSession.execute(query, sessionId);
222+
var count = resultSet.all().get(0).getLong(0);
223+
224+
assertThat(count).isZero();
225+
});
226+
}
227+
60228
@SpringBootConfiguration
61229
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
62230
public static class TestApplication {

0 commit comments

Comments
 (0)