Skip to content

Commit be42cc0

Browse files
committed
Improve JdbcChatMemoryRepository to simplify findConversationIds()
Signed-off-by: Yanming Zhou <[email protected]>
1 parent b219c21 commit be42cc0

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,7 @@ private JdbcChatMemoryRepository(DataSource dataSource, JdbcChatMemoryRepository
7878

7979
@Override
8080
public List<String> findConversationIds() {
81-
List<String> conversationIds = this.jdbcTemplate.query(dialect.getSelectConversationIdsSql(), rs -> {
82-
var ids = new ArrayList<String>();
83-
while (rs.next()) {
84-
ids.add(rs.getString(1));
85-
}
86-
return ids;
87-
});
88-
return conversationIds != null ? conversationIds : List.of();
81+
return this.jdbcTemplate.queryForList(dialect.getSelectConversationIdsSql(), String.class);
8982
}
9083

9184
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ void saveMessagesSingleMessage(String content, MessageType messageType) {
7474

7575
chatMemoryRepository.saveAll(conversationId, List.of(message));
7676

77+
assertThat(chatMemoryRepository.findConversationIds()).contains(conversationId);
78+
7779
// Use dialect to get the appropriate SQL query
7880
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
7981
String selectSql = dialect.getSelectMessagesSql()
@@ -96,6 +98,8 @@ void saveMessagesMultipleMessages() {
9698

9799
chatMemoryRepository.saveAll(conversationId, messages);
98100

101+
assertThat(chatMemoryRepository.findConversationIds()).contains(conversationId);
102+
99103
// Use dialect to get the appropriate SQL query
100104
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
101105
String selectSql = dialect.getSelectMessagesSql()

0 commit comments

Comments
 (0)