|
16 | 16 |
|
17 | 17 | package org.springframework.ai.chat.memory.repository.jdbc; |
18 | 18 |
|
19 | | -import java.sql.PreparedStatement; |
20 | | -import java.sql.ResultSet; |
21 | | -import java.sql.SQLException; |
22 | | -import java.sql.Timestamp; |
| 19 | +import java.sql.*; |
23 | 20 | import java.time.Instant; |
24 | 21 | import java.util.ArrayList; |
25 | 22 | import java.util.List; |
| 23 | +import java.util.Optional; |
26 | 24 | import java.util.concurrent.atomic.AtomicLong; |
27 | 25 |
|
28 | 26 | import org.springframework.ai.chat.memory.ChatMemoryRepository; |
|
35 | 33 | import org.springframework.jdbc.core.BatchPreparedStatementSetter; |
36 | 34 | import org.springframework.jdbc.core.JdbcTemplate; |
37 | 35 | import org.springframework.jdbc.core.RowMapper; |
| 36 | +import org.springframework.jdbc.datasource.DataSourceUtils; |
38 | 37 | import org.springframework.lang.Nullable; |
39 | 38 | import org.springframework.util.Assert; |
40 | 39 |
|
@@ -83,9 +82,31 @@ public void saveAll(String conversationId, List<Message> messages) { |
83 | 82 | Assert.hasText(conversationId, "conversationId cannot be null or empty"); |
84 | 83 | Assert.notNull(messages, "messages cannot be null"); |
85 | 84 | Assert.noNullElements(messages, "messages cannot contain null elements"); |
86 | | - this.deleteByConversationId(conversationId); |
87 | | - this.jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(), |
88 | | - new AddBatchPreparedStatement(conversationId, messages)); |
| 85 | + |
| 86 | + Connection connection = null; |
| 87 | + Assert.notNull(jdbcTemplate.getDataSource(), "jdbcTemplate.getDataSource() cannot be null"); |
| 88 | + try { |
| 89 | + connection = DataSourceUtils.getConnection(jdbcTemplate.getDataSource()); |
| 90 | + connection.setAutoCommit(false); |
| 91 | + this.deleteByConversationId(conversationId); |
| 92 | + this.jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(), |
| 93 | + new AddBatchPreparedStatement(conversationId, messages)); |
| 94 | + connection.commit(); |
| 95 | + } |
| 96 | + catch (SQLException ex) { |
| 97 | + try { |
| 98 | + connection.rollback(); |
| 99 | + } |
| 100 | + catch (SQLException e) { |
| 101 | + throw new RuntimeException("Transaction rollback exception", e); |
| 102 | + } |
| 103 | + throw new RuntimeException("save messages failed", ex); |
| 104 | + } |
| 105 | + finally { |
| 106 | + Optional.ofNullable(connection) |
| 107 | + .ifPresent(conn -> DataSourceUtils.releaseConnection(conn, jdbcTemplate.getDataSource())); |
| 108 | + } |
| 109 | + |
89 | 110 | } |
90 | 111 |
|
91 | 112 | @Override |
|
0 commit comments