Skip to content

Commit 40da194

Browse files
committed
JdbcChatMemoryRepository should use the provided JdbcTemplate
Before this commit, the underlying `JdbcTemplate` is created like `new JdbcTemplate(providedJdbcTemplate.getDataSource())`, it means that settings on provided `JdbcTemplate` will lose. Signed-off-by: Yanming Zhou <[email protected]>
1 parent b219c21 commit 40da194

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@
3737
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
3838
import org.springframework.lang.Nullable;
3939
import org.springframework.transaction.PlatformTransactionManager;
40-
import org.springframework.transaction.TransactionDefinition;
41-
import org.springframework.transaction.TransactionException;
42-
import org.springframework.transaction.TransactionStatus;
43-
import org.springframework.transaction.support.AbstractPlatformTransactionManager;
44-
import org.springframework.transaction.support.DefaultTransactionStatus;
4540
import org.springframework.transaction.support.TransactionTemplate;
4641
import org.springframework.util.Assert;
4742
import org.slf4j.Logger;
@@ -54,6 +49,7 @@
5449
* @author Thomas Vitale
5550
* @author Linar Abzaltdinov
5651
* @author Mark Pollack
52+
* @author Yanming Zhou
5753
* @since 1.0.0
5854
*/
5955
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
@@ -66,14 +62,14 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
6662

6763
private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryRepository.class);
6864

69-
private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepositoryDialect dialect,
65+
private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect,
7066
PlatformTransactionManager txManager) {
71-
Assert.notNull(dataSource, "dataSource cannot be null");
67+
Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null");
7268
Assert.notNull(dialect, "dialect cannot be null");
73-
this.jdbcTemplate = new JdbcTemplate(dataSource);
69+
this.jdbcTemplate = jdbcTemplate;
7470
this.dialect = dialect;
7571
this.transactionTemplate = new TransactionTemplate(
76-
txManager != null ? txManager : new DataSourceTransactionManager(dataSource));
72+
txManager != null ? txManager : new DataSourceTransactionManager(jdbcTemplate.getDataSource()));
7773
}
7874

7975
@Override
@@ -200,7 +196,18 @@ public Builder transactionManager(PlatformTransactionManager txManager) {
200196
public JdbcChatMemoryRepository build() {
201197
DataSource effectiveDataSource = resolveDataSource();
202198
JdbcChatMemoryRepositoryDialect effectiveDialect = resolveDialect(effectiveDataSource);
203-
return new JdbcChatMemoryRepository(effectiveDataSource, effectiveDialect, this.platformTransactionManager);
199+
return new JdbcChatMemoryRepository(resolveJdbcTemplate(), effectiveDialect,
200+
this.platformTransactionManager);
201+
}
202+
203+
private JdbcTemplate resolveJdbcTemplate() {
204+
if (this.jdbcTemplate != null) {
205+
return this.jdbcTemplate;
206+
}
207+
if (this.dataSource != null) {
208+
return new JdbcTemplate(this.dataSource);
209+
}
210+
throw new IllegalArgumentException("DataSource must be set (either via dataSource() or jdbcTemplate())");
204211
}
205212

206213
private DataSource resolveDataSource() {

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryBuilderTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.junit.jupiter.api.Test;
2525

26+
import org.springframework.jdbc.core.JdbcTemplate;
2627
import org.springframework.transaction.PlatformTransactionManager;
2728

2829
import static org.assertj.core.api.Assertions.assertThat;
@@ -34,6 +35,7 @@
3435
* Tests for {@link JdbcChatMemoryRepository.Builder}.
3536
*
3637
* @author Mark Pollack
38+
* @author Yanming Zhou
3739
*/
3840
public class JdbcChatMemoryRepositoryBuilderTests {
3941

@@ -223,4 +225,14 @@ void testBuilderPreferenceForExplicitDialect() throws SQLException {
223225
// for this)
224226
}
225227

228+
@Test
229+
void repositoryShouldUseProvidedJdbcTemplate() throws SQLException {
230+
DataSource dataSource = mock(DataSource.class);
231+
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
232+
233+
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build();
234+
235+
assertThat(repository).extracting("jdbcTemplate").isSameAs(jdbcTemplate);
236+
}
237+
226238
}

0 commit comments

Comments
 (0)