diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java index b24254deb48..4f61c984aa5 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java @@ -37,11 +37,6 @@ import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.lang.Nullable; import org.springframework.transaction.PlatformTransactionManager; -import org.springframework.transaction.TransactionDefinition; -import org.springframework.transaction.TransactionException; -import org.springframework.transaction.TransactionStatus; -import org.springframework.transaction.support.AbstractPlatformTransactionManager; -import org.springframework.transaction.support.DefaultTransactionStatus; import org.springframework.transaction.support.TransactionTemplate; import org.springframework.util.Assert; import org.slf4j.Logger; @@ -54,6 +49,7 @@ * @author Thomas Vitale * @author Linar Abzaltdinov * @author Mark Pollack + * @author Yanming Zhou * @since 1.0.0 */ public class JdbcChatMemoryRepository implements ChatMemoryRepository { @@ -66,14 +62,14 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository { private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepository.class); - private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepositoryDialect dialect, + private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect, PlatformTransactionManager txManager) { - Assert.notNull(dataSource, "dataSource cannot be null"); + Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); Assert.notNull(dialect, "dialect cannot be null"); - this.jdbcTemplate = new JdbcTemplate(dataSource); + this.jdbcTemplate = jdbcTemplate; this.dialect = dialect; this.transactionTemplate = new TransactionTemplate( - txManager != null ? txManager : new DataSourceTransactionManager(dataSource)); + txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource())); } @Override @@ -200,7 +196,18 @@ public Builder transactionManager(PlatformTransactionManager txManager) { public JdbcChatMemoryRepository build() { DataSource effectiveDataSource = resolveDataSource(); JdbcChatMemoryRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource); - return new JdbcChatMemoryRepository(effectiveDataSource, effectiveDialect, this.platformTransactionManager); + return new JdbcChatMemoryRepository(resolveJdbcTemplate(), effectiveDialect, + this.platformTransactionManager); + } + + private JdbcTemplate resolveJdbcTemplate() { + if (this.jdbcTemplate != null) { + return this.jdbcTemplate; + } + if (this.dataSource != null) { + return new JdbcTemplate(this.dataSource); + } + throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())"); } private DataSource resolveDataSource() { diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java index c6eb00c4abe..c0d50544a85 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.transaction.PlatformTransactionManager; import static org.assertj.core.api.Assertions.assertThat; @@ -34,6 +35,7 @@ * Tests for {@link JdbcChatMemoryRepository.Builder}. * * @author Mark Pollack + * @author Yanming Zhou */ public class JdbcChatMemoryRepositoryBuilderTests { @@ -223,4 +225,14 @@ void testBuilderPreferenceForExplicitDialect() throws SQLException { // for this) } + @Test + void repositoryShouldUseProvidedJdbcTemplate() throws SQLException { + DataSource dataSource = mock(DataSource.class); + JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource); + + JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build(); + + assertThat(repository).extracting("jdbcTemplate").isSameAs(jdbcTemplate); + } + }