Skip to content

Commit 5b74843

Browse files
committed
feat: implement refresh test
Signed-off-by: potato <[email protected]>
1 parent 5099310 commit 5b74843

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,31 @@ void testMessageOrder() {
186186
"4-Fourth message");
187187
}
188188

189+
@Test
190+
void refreshConversation() {
191+
var conversationId = UUID.randomUUID().toString();
192+
List<Message> initialMessages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"),
193+
new UserMessage("How are you?"));
194+
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
195+
196+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
197+
198+
// Define changes
199+
List<Message> toDelete = List.of(new UserMessage("How are you?"));
200+
List<Message> toAdd = List.of(new AssistantMessage("I am fine, thank you."));
201+
202+
// Apply changes
203+
this.chatMemoryRepository.refresh(conversationId, toDelete, toAdd);
204+
205+
// Verify final state
206+
List<Message> finalMessages = this.chatMemoryRepository.findByConversationId(conversationId);
207+
assertThat(finalMessages).hasSize(3);
208+
assertThat(finalMessages).contains(new UserMessage("Hello"));
209+
assertThat(finalMessages).contains(new AssistantMessage("Hi there"));
210+
assertThat(finalMessages).contains(new AssistantMessage("I am fine, thank you."));
211+
assertThat(finalMessages).doesNotContain(new UserMessage("How are you?"));
212+
}
213+
189214
/**
190215
* Base configuration for all integration tests.
191216
*/

memory/repository/spring-ai-model-chat-memory-repository-neo4j/src/test/java/org/springframework/ai/chat/memory/repository/neo4j/Neo4jChatMemoryRepositoryIT.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,45 @@ void saveAndFindMessagesWithEmptyContentOrMetadata() {
399399
// messageType
400400
}
401401

402+
@Test
403+
void refreshConversation() {
404+
var conversationId = UUID.randomUUID().toString();
405+
406+
// 1. Save initial messages
407+
var initialMessages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"),
408+
new UserMessage("How are you?"));
409+
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
410+
411+
// Retrieve to get metadata (especially the generated message IDs)
412+
List<Message> savedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
413+
assertThat(savedMessages).hasSize(3);
414+
415+
// 2. Define changes
416+
var messageToDelete = savedMessages.stream().filter(m -> m.getText().equals("How are you?")).findFirst().get();
417+
var toDelete = List.of(messageToDelete);
418+
var toAdd = List.of(new AssistantMessage("I am fine, thank you."));
419+
420+
// 3. Apply changes
421+
this.chatMemoryRepository.refresh(conversationId, toDelete, toAdd);
422+
423+
// 4. Verify final state
424+
List<Message> finalMessages = this.chatMemoryRepository.findByConversationId(conversationId);
425+
assertThat(finalMessages).hasSize(3);
426+
427+
List<String> finalContents = finalMessages.stream().map(Message::getText).toList();
428+
assertThat(finalContents).contains("Hello", "Hi there", "I am fine, thank you.");
429+
assertThat(finalContents).doesNotContain("How are you?");
430+
431+
// Verify directly in the database
432+
try (Session session = this.driver.session()) {
433+
var result = session.run(
434+
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
435+
.formatted(this.config.getSessionLabel(), this.config.getMessageLabel()),
436+
Map.of("conversationId", conversationId));
437+
assertThat(result.single().get("count").asLong()).isEqualTo(3);
438+
}
439+
}
440+
402441
private Message createMessageByType(String content, MessageType messageType) {
403442
return switch (messageType) {
404443
case ASSISTANT -> new AssistantMessage(content);

spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,25 @@ void messagesWithNullElementsNotAllowed() {
153153
.hasMessageContaining("messages cannot contain null elements");
154154
}
155155

156+
@Test
157+
void refreshConversation() {
158+
String conversationId = UUID.randomUUID().toString();
159+
List<Message> initialMessages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi"));
160+
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
161+
162+
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(2);
163+
164+
List<Message> toDelete = List.of(new UserMessage("Hello"));
165+
List<Message> toAdd = List.of(new UserMessage("How are you?"), new AssistantMessage("I'm fine, thanks!"));
166+
167+
this.chatMemoryRepository.refresh(conversationId, toDelete, toAdd);
168+
169+
List<Message> updatedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
170+
assertThat(updatedMessages).hasSize(3);
171+
assertThat(updatedMessages).contains(new AssistantMessage("Hi"));
172+
assertThat(updatedMessages).contains(new UserMessage("How are you?"));
173+
assertThat(updatedMessages).contains(new AssistantMessage("I'm fine, thanks!"));
174+
assertThat(updatedMessages).doesNotContain(new UserMessage("Hello"));
175+
}
176+
156177
}

0 commit comments

Comments
 (0)