1616
1717package org .springframework .ai .chat .memory .jdbc ;
1818
19- import java .sql .PreparedStatement ;
20- import java .sql .ResultSet ;
21- import java .sql .SQLException ;
22- import java .util .List ;
23-
2419import org .springframework .ai .chat .memory .ChatMemory ;
25- import org .springframework .ai .chat .messages .AssistantMessage ;
26- import org .springframework .ai .chat .messages .Message ;
27- import org .springframework .ai .chat .messages .MessageType ;
28- import org .springframework .ai .chat .messages .SystemMessage ;
29- import org .springframework .ai .chat .messages .UserMessage ;
20+ import org .springframework .ai .chat .messages .*;
21+ import org .springframework .boot .jdbc .DatabaseDriver ;
3022import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
3123import org .springframework .jdbc .core .JdbcTemplate ;
3224import org .springframework .jdbc .core .RowMapper ;
25+ import org .springframework .util .Assert ;
26+
27+ import java .sql .Connection ;
28+ import java .sql .PreparedStatement ;
29+ import java .sql .ResultSet ;
30+ import java .sql .SQLException ;
31+ import java .util .List ;
3332
3433/**
3534 * An implementation of {@link ChatMemory} for JDBC. Creating an instance of
3635 * JdbcChatMemory example:
3736 * <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
3837 *
3938 * @author Jonathan Leijendekker
39+ * @author Xavier Chopin
4040 * @since 1.0.0
4141 */
4242public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +45,33 @@ public class JdbcChatMemory implements ChatMemory {
4545 INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""" ;
4646
4747 private static final String QUERY_GET = """
48- SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""" ;
48+ SELECT content, type \
49+ FROM ai_chat_memory \
50+ WHERE conversation_id = ? \
51+ ORDER BY "timestamp" DESC \
52+ LIMIT ?
53+ """ ;
54+
55+ private static final String MSSQL_QUERY_GET = """
56+ SELECT content, type \
57+ FROM ( \
58+ SELECT TOP (?) content, type, [timestamp] \
59+ FROM ai_chat_memory \
60+ WHERE conversation_id = ? \
61+ ORDER BY [timestamp] DESC \
62+ ) AS recent \
63+ ORDER BY [timestamp] ASC \
64+ """ ;
4965
5066 private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?" ;
5167
5268 private final JdbcTemplate jdbcTemplate ;
5369
70+ private final DatabaseDriver driver ;
71+
5472 public JdbcChatMemory (JdbcChatMemoryConfig config ) {
5573 this .jdbcTemplate = config .getJdbcTemplate ();
74+ this .driver = this .detectDialect (this .jdbcTemplate );
5675 }
5776
5877 public static JdbcChatMemory create (JdbcChatMemoryConfig config ) {
@@ -66,16 +85,19 @@ public void add(String conversationId, List<Message> messages) {
6685
6786 @ Override
6887 public List <Message > get (String conversationId , int lastN ) {
69- return this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
88+ return switch (driver ) {
89+ case SQLSERVER -> this .jdbcTemplate .query (MSSQL_QUERY_GET , new MessageRowMapper (), lastN , conversationId );
90+ default -> this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
91+ };
7092 }
7193
7294 @ Override
7395 public void clear (String conversationId ) {
7496 this .jdbcTemplate .update (QUERY_CLEAR , conversationId );
7597 }
7698
77- private record AddBatchPreparedStatement (String conversationId ,
78- List < Message > messages ) implements BatchPreparedStatementSetter {
99+ private record AddBatchPreparedStatement (String conversationId , List < Message > messages )
100+ implements BatchPreparedStatementSetter {
79101 @ Override
80102 public void setValues (PreparedStatement ps , int i ) throws SQLException {
81103 var message = this .messages .get (i );
@@ -108,4 +130,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108130
109131 }
110132
133+ private DatabaseDriver detectDialect (JdbcTemplate jdbcTemplate ) {
134+ try {
135+ Assert .notNull (jdbcTemplate .getDataSource (), "jdbcTemplate.dataSource must not be null" );
136+ try (Connection conn = jdbcTemplate .getDataSource ().getConnection ()) {
137+ String url = conn .getMetaData ().getURL ();
138+ return DatabaseDriver .fromJdbcUrl (url );
139+ }
140+ } catch (SQLException ex ) {
141+ throw new IllegalStateException ("Impossible to detect dialect" , ex );
142+ }
143+ }
111144}
0 commit comments