Skip to content

Commit 1bf5787

Browse files
linarkousobychacko
authored andcommitted
Fix message order in Cassandra chat memory after retrieving from db
Fixes: #2815 Added integration tests Replaced deprecated CassandraContainer test-container by actual one Reordered imports in CassandraChatMemoryIT Signed-off-by: Linar Abzaltdinov <[email protected]> Fixing CassandraChatMemoryAutoConfigurationIT for the correct ordering Signed-off-by: Soby Chacko <[email protected]>
1 parent b74e308 commit 1bf5787

File tree

3 files changed

+179
-7
lines changed

3 files changed

+179
-7
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -80,11 +80,11 @@ void addAndGet() {
8080

8181
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).hasSize(2);
8282
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getMessageType())
83-
.isEqualTo(MessageType.USER);
84-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getText()).isEqualTo("test question");
85-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
8683
.isEqualTo(MessageType.ASSISTANT);
87-
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test answer");
84+
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(1).getText()).isEqualTo("test answer");
85+
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getMessageType())
86+
.isEqualTo(MessageType.USER);
87+
assertThat(memory.get(sessionId, Integer.MAX_VALUE).get(0).getText()).isEqualTo("test question");
8888

8989
CassandraChatMemoryProperties properties = context.getBean(CassandraChatMemoryProperties.class);
9090
assertThat(properties.getTimeToLive()).isEqualTo(getTimeToLive());

memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818

1919
import java.time.Instant;
2020
import java.util.ArrayList;
21+
import java.util.Collections;
2122
import java.util.List;
2223
import java.util.concurrent.atomic.AtomicLong;
2324

@@ -158,6 +159,7 @@ public List<Message> get(String sessionId, int lastN) {
158159
messages.add(new UserMessage(user));
159160
}
160161
}
162+
Collections.reverse(messages);
161163
return messages;
162164
}
163165

memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,21 +17,34 @@
1717
package org.springframework.ai.chat.memory.cassandra;
1818

1919
import java.time.Duration;
20+
import java.util.List;
21+
import java.util.UUID;
2022

2123
import com.datastax.oss.driver.api.core.CqlSession;
2224
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
25+
import com.datastax.oss.driver.api.core.cql.ResultSet;
2326
import org.junit.jupiter.api.Assertions;
2427
import org.junit.jupiter.api.Test;
28+
29+
import org.junit.jupiter.params.ParameterizedTest;
30+
import org.junit.jupiter.params.provider.CsvSource;
2531
import org.testcontainers.cassandra.CassandraContainer;
2632
import org.testcontainers.junit.jupiter.Container;
2733
import org.testcontainers.junit.jupiter.Testcontainers;
2834

35+
import org.springframework.ai.chat.memory.ChatMemory;
36+
import org.springframework.ai.chat.messages.AssistantMessage;
37+
import org.springframework.ai.chat.messages.Message;
38+
import org.springframework.ai.chat.messages.MessageType;
39+
import org.springframework.ai.chat.messages.UserMessage;
2940
import org.springframework.boot.SpringBootConfiguration;
3041
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
3142
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
3243
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
3344
import org.springframework.context.annotation.Bean;
3445

46+
import static org.assertj.core.api.Assertions.assertThat;
47+
3548
/**
3649
* Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT`
3750
*
@@ -57,6 +70,163 @@ void ensureBeanGetsCreated() {
5770
});
5871
}
5972

73+
@ParameterizedTest
74+
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" })
75+
void add_shouldInsertSingleMessage(String content, MessageType messageType) {
76+
this.contextRunner.run(context -> {
77+
var chatMemory = context.getBean(ChatMemory.class);
78+
var sessionId = UUID.randomUUID().toString();
79+
var message = switch (messageType) {
80+
case ASSISTANT -> new AssistantMessage(content);
81+
case USER -> new UserMessage(content);
82+
default -> throw new IllegalArgumentException("Type not supported: " + messageType);
83+
};
84+
85+
chatMemory.add(sessionId, message);
86+
87+
var cqlSession = context.getBean(CqlSession.class);
88+
var query = """
89+
SELECT session_id, message_timestamp, a, u
90+
FROM test_springframework.ai_chat_memory
91+
WHERE session_id = ?
92+
""";
93+
ResultSet resultSet = cqlSession.execute(query, sessionId);
94+
var rows = resultSet.all();
95+
96+
assertThat(rows.size()).isEqualTo(1);
97+
98+
var firstRow = rows.get(0);
99+
100+
assertThat(firstRow.getString("session_id")).isEqualTo(sessionId);
101+
assertThat(firstRow.getInstant("message_timestamp")).isNotNull();
102+
if (messageType == MessageType.ASSISTANT) {
103+
assertThat(firstRow.getString("a")).isEqualTo(content);
104+
assertThat(firstRow.getString("u")).isNull();
105+
}
106+
else if (messageType == MessageType.USER) {
107+
assertThat(firstRow.getString("a")).isNull();
108+
assertThat(firstRow.getString("u")).isEqualTo(content);
109+
}
110+
});
111+
}
112+
113+
@Test
114+
void add_shouldInsertMessages() {
115+
this.contextRunner.run(context -> {
116+
var chatMemory = context.getBean(ChatMemory.class);
117+
var sessionId = UUID.randomUUID().toString();
118+
var messages = List.<Message>of(new AssistantMessage("Message from assistant"),
119+
new UserMessage("Message from user"));
120+
121+
chatMemory.add(sessionId, messages);
122+
123+
var cqlSession = context.getBean(CqlSession.class);
124+
var query = """
125+
SELECT session_id, message_timestamp, a, u
126+
FROM test_springframework.ai_chat_memory
127+
WHERE session_id = ?
128+
ORDER BY message_timestamp ASC
129+
""";
130+
ResultSet resultSet = cqlSession.execute(query, sessionId);
131+
var rows = resultSet.all();
132+
133+
assertThat(rows.size()).isEqualTo(messages.size());
134+
135+
for (var i = 0; i < messages.size(); i++) {
136+
var message = messages.get(i);
137+
var result = rows.get(i);
138+
139+
assertThat(result.getString("session_id")).isNotNull();
140+
assertThat(result.getString("session_id")).isEqualTo(sessionId);
141+
if (message.getMessageType() == MessageType.ASSISTANT) {
142+
assertThat(result.getString("a")).isEqualTo(message.getText());
143+
assertThat(result.getString("u")).isNull();
144+
}
145+
else if (message.getMessageType() == MessageType.USER) {
146+
assertThat(result.getString("a")).isNull();
147+
assertThat(result.getString("u")).isEqualTo(message.getText());
148+
}
149+
assertThat(result.getInstant("message_timestamp")).isNotNull();
150+
}
151+
});
152+
}
153+
154+
@Test
155+
void get_shouldReturnMessages() {
156+
this.contextRunner.run(context -> {
157+
var chatMemory = context.getBean(ChatMemory.class);
158+
var sessionId = UUID.randomUUID().toString();
159+
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + sessionId),
160+
new AssistantMessage("Message from assistant 2 - " + sessionId),
161+
new UserMessage("Message from user - " + sessionId));
162+
163+
chatMemory.add(sessionId, messages);
164+
165+
var results = chatMemory.get(sessionId, Integer.MAX_VALUE);
166+
167+
assertThat(results.size()).isEqualTo(messages.size());
168+
169+
for (var i = 0; i < messages.size(); i++) {
170+
var message = messages.get(i);
171+
var result = results.get(i);
172+
173+
assertThat(result.getMessageType()).isEqualTo(message.getMessageType());
174+
assertThat(result.getText()).isEqualTo(message.getText());
175+
}
176+
});
177+
}
178+
179+
@Test
180+
void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() {
181+
this.contextRunner.run(context -> {
182+
var chatMemory = context.getBean(ChatMemory.class);
183+
var sessionId = UUID.randomUUID().toString();
184+
var userMessage = new UserMessage("Message from user - " + sessionId);
185+
var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId);
186+
187+
chatMemory.add(sessionId, userMessage);
188+
chatMemory.add(sessionId, assistantMessage);
189+
190+
var results = chatMemory.get(sessionId, Integer.MAX_VALUE);
191+
192+
assertThat(results.size()).isEqualTo(2);
193+
194+
var messages = List.<Message>of(userMessage, assistantMessage);
195+
for (var i = 0; i < messages.size(); i++) {
196+
var message = messages.get(i);
197+
var result = results.get(i);
198+
199+
assertThat(result.getMessageType()).isEqualTo(message.getMessageType());
200+
assertThat(result.getText()).isEqualTo(message.getText());
201+
}
202+
});
203+
}
204+
205+
@Test
206+
void clear_shouldDeleteMessages() {
207+
this.contextRunner.run(context -> {
208+
var chatMemory = context.getBean(ChatMemory.class);
209+
var sessionId = UUID.randomUUID().toString();
210+
var messages = List.<Message>of(new AssistantMessage("Message from assistant - " + sessionId),
211+
new UserMessage("Message from user - " + sessionId));
212+
213+
chatMemory.add(sessionId, messages);
214+
215+
chatMemory.clear(sessionId);
216+
217+
var cqlSession = context.getBean(CqlSession.class);
218+
var query = """
219+
SELECT COUNT(*)
220+
FROM test_springframework.ai_chat_memory
221+
WHERE session_id = ?
222+
""";
223+
ResultSet resultSet = cqlSession.execute(query, sessionId);
224+
var count = resultSet.all().get(0).getLong(0);
225+
226+
assertThat(count).isZero();
227+
});
228+
}
229+
60230
@SpringBootConfiguration
61231
@EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
62232
public static class TestApplication {

0 commit comments

Comments
 (0)