diff --git a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java index 85d8c1c3265..68df4791949 100644 --- a/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java +++ b/memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java @@ -16,10 +16,7 @@ package org.springframework.ai.chat.memory.repository.jdbc; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; +import java.sql.*; import java.time.Instant; import java.util.ArrayList; import java.util.List; @@ -35,7 +32,9 @@ import org.springframework.jdbc.core.BatchPreparedStatementSetter; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.lang.Nullable; +import org.springframework.transaction.support.TransactionTemplate; import org.springframework.util.Assert; /** @@ -51,13 +50,18 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository { private final JdbcTemplate jdbcTemplate; + private final TransactionTemplate transactionTemplate; + private final JdbcChatMemoryRepositoryDialect dialect; private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, JdbcChatMemoryRepositoryDialect dialect) { Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); Assert.notNull(dialect, "dialect cannot be null"); + Assert.notNull(jdbcTemplate.getDataSource(), "dataSource can not be null"); this.jdbcTemplate = jdbcTemplate; this.dialect = dialect; + this.transactionTemplate = new TransactionTemplate( + new DataSourceTransactionManager(jdbcTemplate.getDataSource())); } @Override @@ -83,9 +87,13 @@ public void saveAll(String conversationId, List messages) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); - this.deleteByConversationId(conversationId); - this.jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(), - new AddBatchPreparedStatement(conversationId, messages)); + + transactionTemplate.execute(status -> { + deleteByConversationId(conversationId); + jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(), + new AddBatchPreparedStatement(conversationId, messages)); + return null; + }); } @Override