Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.test.context.ContextConfiguration;

import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

Expand All @@ -47,20 +46,17 @@
* Base class for integration tests for {@link JdbcChatMemoryRepository}.
*
* @author Mark Pollack
* @author Yanming Zhou
*/
@ContextConfiguration(classes = AbstractJdbcChatMemoryRepositoryIT.TestConfiguration.class)
public abstract class AbstractJdbcChatMemoryRepositoryIT {

@Autowired
protected ChatMemoryRepository chatMemoryRepository;
protected JdbcChatMemoryRepository chatMemoryRepository;

@Autowired
protected JdbcTemplate jdbcTemplate;

@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
}

@ParameterizedTest
@CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" })
void saveMessagesSingleMessage(String content, MessageType messageType) {
Expand Down Expand Up @@ -158,11 +154,6 @@ void deleteMessagesByConversationId() {

@Test
void testMessageOrder() {
// Create a repository using the from method to detect the dialect
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()))
.build();

var conversationId = UUID.randomUUID().toString();

Expand All @@ -174,10 +165,10 @@ void testMessageOrder() {

// Save messages in the expected order
List<Message> orderedMessages = List.of(firstMessage, secondMessage, thirdMessage, fourthMessage);
repository.saveAll(conversationId, orderedMessages);
chatMemoryRepository.saveAll(conversationId, orderedMessages);

// Retrieve messages using the repository
List<Message> retrievedMessages = repository.findByConversationId(conversationId);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(4);

// Get the actual order from the retrieved messages
Expand All @@ -192,14 +183,11 @@ void testMessageOrder() {
* Base configuration for all integration tests.
*/
@ImportAutoConfiguration({ DataSourceAutoConfiguration.class, JdbcTemplateAutoConfiguration.class })
static abstract class BaseTestConfiguration {
static class TestConfiguration {

@Bean
ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate, DataSource dataSource) {
return JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
.build();
ChatMemoryRepository chatMemoryRepository(DataSource dataSource) {
return JdbcChatMemoryRepository.builder().dataSource(dataSource).build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.chat.memory.repository.jdbc;

import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.jdbc.Sql;
Expand All @@ -27,15 +26,11 @@
* @author Jonathan Leijendekker
* @author Thomas Vitale
* @author Mark Pollack
* @author Yanming Zhou
*/
@SpringBootTest(classes = JdbcChatMemoryRepositoryMysqlIT.TestConfiguration.class)
@SpringBootTest
@TestPropertySource(properties = { "spring.datasource.url=jdbc:tc:mariadb:10.3.39:///" })
@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-mariadb.sql")
class JdbcChatMemoryRepositoryMysqlIT extends AbstractJdbcChatMemoryRepositoryIT {

@SpringBootConfiguration
static class TestConfiguration extends BaseTestConfiguration {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,84 +16,21 @@

package org.springframework.ai.chat.memory.repository.jdbc;

import java.util.List;
import java.util.UUID;
import javax.sql.DataSource;

import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.jdbc.Sql;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for {@link JdbcChatMemoryRepository} with PostgreSQL.
*
* @author Jonathan Leijendekker
* @author Thomas Vitale
* @author Mark Pollack
* @author Yanming Zhou
*/
@SpringBootTest(classes = JdbcChatMemoryRepositoryPostgresqlIT.TestConfiguration.class)
@SpringBootTest
@TestPropertySource(properties = "spring.datasource.url=jdbc:tc:postgresql:17:///")
@Sql(scripts = "classpath:org/springframework/ai/chat/memory/repository/jdbc/schema-postgresql.sql")
class JdbcChatMemoryRepositoryPostgresqlIT extends AbstractJdbcChatMemoryRepositoryIT {

@Test
void repositoryWithExplicitTransactionManager() {
// Get the repository with explicit transaction manager
ChatMemoryRepository repositoryWithTxManager = TestConfiguration
.chatMemoryRepositoryWithTransactionManager(jdbcTemplate, jdbcTemplate.getDataSource());

var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message with transaction manager - " + conversationId),
new UserMessage("User message with transaction manager - " + conversationId));

// Save messages using the repository with explicit transaction manager
repositoryWithTxManager.saveAll(conversationId, messages);

// Verify messages were saved correctly
var savedMessages = repositoryWithTxManager.findByConversationId(conversationId);
assertThat(savedMessages).hasSize(2);
assertThat(savedMessages).isEqualTo(messages);

// Verify transaction works by updating and checking atomicity
var newMessages = List.<Message>of(new SystemMessage("New system message - " + conversationId));
repositoryWithTxManager.saveAll(conversationId, newMessages);

// The old messages should be deleted and only the new one should exist
var updatedMessages = repositoryWithTxManager.findByConversationId(conversationId);
assertThat(updatedMessages).hasSize(1);
assertThat(updatedMessages).isEqualTo(newMessages);
}

@SpringBootConfiguration
static class TestConfiguration extends BaseTestConfiguration {

@Bean
ChatMemoryRepository chatMemoryRepositoryWithTxManager(JdbcTemplate jdbcTemplate, DataSource dataSource) {
return chatMemoryRepositoryWithTransactionManager(jdbcTemplate, dataSource);
}

static ChatMemoryRepository chatMemoryRepositoryWithTransactionManager(JdbcTemplate jdbcTemplate,
DataSource dataSource) {
return JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(dataSource))
.transactionManager(new DataSourceTransactionManager(dataSource))
.build();
}

}

}