diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java index cf10413a550..30f4690dbc0 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java @@ -48,30 +48,54 @@ */ public class JdbcChatMemoryRepository implements ChatMemoryRepository { - private static final String QUERY_GET_IDS = """ - SELECT DISTINCT conversation_id FROM ai_chat_memory - """; + private final String queryGetIds; - private static final String QUERY_ADD = """ - INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?) - """; + private final String queryAdd; - private static final String QUERY_GET = """ - SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" - """; + private final String queryGet; - private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + private final String queryClear; private final JdbcTemplate jdbcTemplate; - private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) { + public final static String DEFAULT_TABLE_NAME = "ai_chat_memory"; + + public final static String DEFAULT_CONVERSION_ID_FIELD_NAME = "conversation_id"; + + public final static String DEFAULT_CONTENT_FIELD_NAME = "content"; + + public final static String DEFAULT_TYPE_FIELD_NAME = "type"; + + public final static String DEFAULT_TIMESTAMP_FIELD_NAME = "\"timestamp\""; + + public final static String DEFAULT_GET_IDS_QUERY = "SELECT DISTINCT %s FROM %s"; + + public final static String DEFAULT_ADD_QUERY = "INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)"; + + public final static String DEFAULT_GET_QUERY = "SELECT %s, %s FROM %s WHERE %s = ? ORDER BY %s"; + + public final static String DEFAULT_CLEAR_QUERY = "DELETE FROM %s WHERE %s = ?"; + + private JdbcChatMemoryRepository(JdbcTemplate jdbcTemplate, String tableName, String conversionIdFiledName, + String contentFiledName, String typeFiledName, String timestampFiledName) { Assert.notNull(jdbcTemplate, "jdbcTemplate cannot be null"); + Assert.notNull(tableName, "tableName cannot be null"); + Assert.notNull(conversionIdFiledName, "conversionIdFiledName cannot be null"); + Assert.notNull(contentFiledName, "contentFiledName cannot be null"); + Assert.notNull(typeFiledName, "typeFiledName cannot be null"); + Assert.notNull(timestampFiledName, "timestampFiledName cannot be null"); this.jdbcTemplate = jdbcTemplate; + this.queryGetIds = DEFAULT_GET_IDS_QUERY.formatted(conversionIdFiledName, tableName); + this.queryAdd = DEFAULT_ADD_QUERY.formatted(tableName, conversionIdFiledName, contentFiledName, typeFiledName, + timestampFiledName); + this.queryGet = DEFAULT_GET_QUERY.formatted(contentFiledName, typeFiledName, tableName, conversionIdFiledName, + timestampFiledName); + this.queryClear = DEFAULT_CLEAR_QUERY.formatted(tableName, conversionIdFiledName); } @Override public List findConversationIds() { - List conversationIds = this.jdbcTemplate.query(QUERY_GET_IDS, rs -> { + List conversationIds = this.jdbcTemplate.query(queryGetIds, rs -> { var ids = new ArrayList(); while (rs.next()) { ids.add(rs.getString(1)); @@ -84,7 +108,7 @@ public List findConversationIds() { @Override public List findByConversationId(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); - return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId); + return this.jdbcTemplate.query(queryGet, new MessageRowMapper(), conversationId); } @Override @@ -93,13 +117,13 @@ public void saveAll(String conversationId, List messages) { Assert.notNull(messages, "messages cannot be null"); Assert.noNullElements(messages, "messages cannot contain null elements"); this.deleteByConversationId(conversationId); - this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + this.jdbcTemplate.batchUpdate(queryAdd, new AddBatchPreparedStatement(conversationId, messages)); } @Override public void deleteByConversationId(String conversationId) { Assert.hasText(conversationId, "conversationId cannot be null or empty"); - this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + this.jdbcTemplate.update(queryClear, conversationId); } private record AddBatchPreparedStatement(String conversationId, List messages, @@ -154,6 +178,16 @@ public static class Builder { private JdbcTemplate jdbcTemplate; + private String tableName = DEFAULT_TABLE_NAME; + + private String conversionIdFiledName = DEFAULT_CONVERSION_ID_FIELD_NAME; + + private String contentFiledName = DEFAULT_CONTENT_FIELD_NAME; + + private String typeFiledName = DEFAULT_TYPE_FIELD_NAME; + + private String timestampFiledName = DEFAULT_TIMESTAMP_FIELD_NAME; + private Builder() { } @@ -162,8 +196,34 @@ public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { return this; } + public Builder tableName(String tableName) { + this.tableName = tableName; + return this; + } + + public Builder conversionIdFiledName(String conversionIdFiledName) { + this.conversionIdFiledName = conversionIdFiledName; + return this; + } + + public Builder contentFiledName(String contentFiledName) { + this.contentFiledName = contentFiledName; + return this; + } + + public Builder typeFiledName(String typeFiledName) { + this.typeFiledName = typeFiledName; + return this; + } + + public Builder timestampFiledName(String timestampFiledName) { + this.timestampFiledName = timestampFiledName; + return this; + } + public JdbcChatMemoryRepository build() { - return new JdbcChatMemoryRepository(this.jdbcTemplate); + return new JdbcChatMemoryRepository(this.jdbcTemplate, this.tableName, this.conversionIdFiledName, + this.contentFiledName, this.typeFiledName, this.timestampFiledName); } } diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java index 5e5abc6ac41..5b9520c2ea1 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryPostgresqlIT.java @@ -41,6 +41,7 @@ import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository.*; /** * Integration tests for {@link JdbcChatMemoryRepository}. @@ -160,7 +161,14 @@ static class TestConfiguration { @Bean ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) { - return JdbcChatMemoryRepository.builder().jdbcTemplate(jdbcTemplate).build(); + return JdbcChatMemoryRepository.builder() + .jdbcTemplate(jdbcTemplate) + .tableName(DEFAULT_TABLE_NAME) + .conversionIdFiledName(DEFAULT_CONVERSION_ID_FIELD_NAME) + .contentFiledName(DEFAULT_CONTENT_FIELD_NAME) + .typeFiledName(DEFAULT_TYPE_FIELD_NAME) + .timestampFiledName(DEFAULT_TIMESTAMP_FIELD_NAME) + .build(); } }