diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java index bc811c3ded6..079caa35bf5 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryConfig; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -35,6 +36,7 @@ /** * @author Jonathan Leijendekker + * @author Thomas Vitale * @since 1.0.0 */ @AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) @@ -46,9 +48,21 @@ public class JdbcChatMemoryAutoConfiguration { @Bean @ConditionalOnMissingBean - public JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + JdbcChatMemoryRepository jdbcChatMemoryRepository(JdbcTemplate jdbcTemplate) { var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + return JdbcChatMemoryRepository.create(config); + } + /** + * @deprecated in favor of providing ChatClient directly a + * {@link org.springframework.ai.chat.memory.MessageWindowChatMemory} with a + * {@link JdbcChatMemoryRepository} instance. + */ + @Bean + @ConditionalOnMissingBean + @Deprecated + JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); return JdbcChatMemory.create(config); } @@ -56,9 +70,8 @@ public JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { @ConditionalOnMissingBean @ConditionalOnProperty(value = "spring.ai.chat.memory.jdbc.initialize-schema", havingValue = "true", matchIfMissing = true) - public DataSourceScriptDatabaseInitializer jdbcChatMemoryScriptDatabaseInitializer(DataSource dataSource) { + DataSourceScriptDatabaseInitializer jdbcChatMemoryScriptDatabaseInitializer(DataSource dataSource) { logger.debug("Initializing JdbcChatMemory schema"); - return new JdbcChatMemoryDataSourceScriptDatabaseInitializer(dataSource); } diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java index df9a49d85b9..871626e9bb1 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java @@ -26,6 +26,7 @@ import org.testcontainers.utility.DockerImageName; import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryRepository; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -38,6 +39,7 @@ /** * @author Jonathan Leijendekker + * @author Thomas Vitale */ @Testcontainers class JdbcChatMemoryAutoConfigurationIT { @@ -96,4 +98,30 @@ void addGetAndClear_shouldAllExecute() { }); } + @Test + void useAutoConfiguredJdbcChatMemoryRepository() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { + var chatMemoryRepository = context.getBean(JdbcChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemoryRepository.save(conversationId, List.of(userMessage)); + + assertThat(chatMemoryRepository.findById(conversationId)).hasSize(1); + assertThat(chatMemoryRepository.findById(conversationId)).isEqualTo(List.of(userMessage)); + + chatMemoryRepository.deleteById(conversationId); + + assertThat(chatMemoryRepository.findById(conversationId)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemoryRepository.save(conversationId, multipleMessages); + + assertThat(chatMemoryRepository.findById(conversationId)).hasSize(multipleMessages.size()); + assertThat(chatMemoryRepository.findById(conversationId)).isEqualTo(multipleMessages); + }); + } + } diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java index 6c9825bac1b..13c652c505c 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java @@ -38,7 +38,11 @@ * * @author Jonathan Leijendekker * @since 1.0.0 + * @deprecated in favor of providing ChatClient directly a + * {@link org.springframework.ai.chat.memory.MessageWindowChatMemory} with a + * {@link JdbcChatMemoryRepository} instance. */ +@Deprecated public class JdbcChatMemory implements ChatMemory { private static final String QUERY_ADD = """ diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java index 5a503aef051..37e5c60fa96 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java @@ -20,7 +20,7 @@ import org.springframework.util.Assert; /** - * Configuration for {@link JdbcChatMemory}. + * Configuration for {@link JdbcChatMemoryRepository}. * * @author Jonathan Leijendekker * @since 1.0.0 diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java new file mode 100644 index 00000000000..3f7e492c052 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepository.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.*; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +/** + * An implementation of {@link ChatMemoryRepository} for JDBC. + * + * @author Jonathan Leijendekker + * @author Thomas Vitale + * @since 1.0.0 + */ +public class JdbcChatMemoryRepository implements ChatMemoryRepository { + + private static final String QUERY_ADD = """ + INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; + + private static final String QUERY_GET = """ + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC"""; + + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + + private final JdbcTemplate jdbcTemplate; + + private JdbcChatMemoryRepository(JdbcChatMemoryConfig config) { + Assert.notNull(config, "config cannot be null"); + this.jdbcTemplate = config.getJdbcTemplate(); + } + + public static JdbcChatMemoryRepository create(JdbcChatMemoryConfig config) { + return new JdbcChatMemoryRepository(config); + } + + @Override + public List findById(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId); + } + + @Override + public void save(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public void deleteById(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + ps.setString(2, message.getText()); + ps.setString(3, message.getMessageType().name()); + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + @Nullable + public Message mapRow(ResultSet rs, int i) throws SQLException { + var content = rs.getString(1); + var type = MessageType.valueOf(rs.getString(2)); + + return switch (type) { + case USER -> new UserMessage(content); + case ASSISTANT -> new AssistantMessage(content); + case SYSTEM -> new SystemMessage(content); + case TOOL -> null; + }; + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java similarity index 94% rename from memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java rename to memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java index 6740602e3f8..3b518733ffb 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHints.java @@ -27,7 +27,7 @@ * * @author Jonathan Leijendekker */ -class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { +class JdbcChatMemoryRepositoryRuntimeHints implements RuntimeHintsRegistrar { @Override public void registerHints(RuntimeHints hints, ClassLoader classLoader) { diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java new file mode 100644 index 00000000000..a26f200dec1 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.memory.jdbc; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories index 4b6f4a8f5ce..7169645c1e9 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories @@ -1,2 +1,2 @@ org.springframework.aot.hint.RuntimeHintsRegistrar=\ -org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints +org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRepositoryRuntimeHints diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryIT.java new file mode 100644 index 00000000000..ec46cb5c2ba --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryRepositoryIT.java @@ -0,0 +1,208 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + +import javax.sql.DataSource; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link JdbcChatMemoryRepository}. + * + * @author Jonathan Leijendekker + * @author Thomas Vitale + */ +@Testcontainers +class JdbcChatMemoryRepositoryIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_test") + .withUsername("postgres") + .withPassword("postgres") + .withCopyFileToContainer( + MountableFile.forClasspathResource("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"), + "/docker-entrypoint-initdb.d/schema.sql"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(JdbcChatMemoryRepositoryIT.TestApplication.class) + .withPropertyValues(String.format("myapp.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("myapp.datasource.username=%s", postgresContainer.getUsername()), + String.format("myapp.datasource.password=%s", postgresContainer.getPassword())); + + @Test + void correctChatMemoryRepositoryInstance() { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(ChatMemoryRepository.class); + assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void saveSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(ChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemoryRepository.save(conversationId, List.of(message)); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + }); + } + + @Test + void saveMultipleMessages() { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(ChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.save(conversationId, messages); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var results = jdbcTemplate.queryForList(query, conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + }); + } + + @Test + void findMessagesByConversationId() { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(ChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.save(conversationId, messages); + + var results = chatMemoryRepository.findById(conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + }); + } + + @Test + void deleteMessagesByConversationId() { + this.contextRunner.run(context -> { + var chatMemoryRepository = context.getBean(ChatMemoryRepository.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemoryRepository.save(conversationId, messages); + + chatMemoryRepository.deleteById(conversationId); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + ChatMemoryRepository chatMemoryRepository(JdbcTemplate jdbcTemplate) { + JdbcChatMemoryConfig config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + return JdbcChatMemoryRepository.create(config); + } + + @Bean + JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("myapp.datasource") + DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public DataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().build(); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java similarity index 83% rename from memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java rename to memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java index 90c65272d72..d507c712eda 100644 --- a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRepositoryRuntimeHintsTest.java @@ -38,18 +38,18 @@ /** * @author Jonathan Leijendekker */ -class JdbcChatMemoryRuntimeHintsTest { +class JdbcChatMemoryRepositoryRuntimeHintsTest { private final RuntimeHints hints = new RuntimeHints(); - private final JdbcChatMemoryRuntimeHints jdbcChatMemoryRuntimeHints = new JdbcChatMemoryRuntimeHints(); + private final JdbcChatMemoryRepositoryRuntimeHints jdbcChatMemoryRepositoryRuntimeHints = new JdbcChatMemoryRepositoryRuntimeHints(); @Test void aotFactoriesContainsRegistrar() { var match = SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") .load(RuntimeHintsRegistrar.class) .stream() - .anyMatch(registrar -> registrar instanceof JdbcChatMemoryRuntimeHints); + .anyMatch(registrar -> registrar instanceof JdbcChatMemoryRepositoryRuntimeHints); assertThat(match).isTrue(); } @@ -57,7 +57,7 @@ void aotFactoriesContainsRegistrar() { @ParameterizedTest @MethodSource("getSchemaFileNames") void jdbcSchemasHasHints(String schemaFileName) { - this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + this.jdbcChatMemoryRepositoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); var predicate = RuntimeHintsPredicates.resource() .forResource("org/springframework/ai/chat/memory/jdbc/" + schemaFileName); @@ -67,7 +67,7 @@ void jdbcSchemasHasHints(String schemaFileName) { @Test void dataSourceHasHints() { - this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + this.jdbcChatMemoryRepositoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index 5dbd922602c..85b7e0cc617 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -34,6 +34,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; @@ -628,6 +630,29 @@ void validateStoreAndMetadata() { assertThat(response).isNotNull(); } + @Test + void chatMemory() { + ChatMemory memory = MessageWindowChatMemory.builder().build(); + String conversationId = "007"; + + UserMessage userMessage1 = new UserMessage("My name is James Bond"); + memory.add(conversationId, userMessage1); + ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response1).isNotNull(); + memory.add(conversationId, response1.getResult().getOutput()); + + UserMessage userMessage2 = new UserMessage("What is my name?"); + memory.add(conversationId, userMessage2); + ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId))); + + assertThat(response2).isNotNull(); + memory.add(conversationId, response2.getResult().getOutput()); + + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResult().getOutput().getText()).contains("James Bond"); + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index 125148b1be1..943ffa47e40 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; @@ -378,6 +379,69 @@ void multiModalityAudioResponse() { logger.info("Response: " + response); } + @Test + void chatMemoryWithDefaults() { + ChatClient chatClient = ChatClient.builder(this.chatModel) + .defaultMemory(MessageWindowChatMemory.builder().build()) + .build(); + + String conversationId = "007"; + + ChatResponse response1 = chatClient.prompt("My name is Bond. James Bond.") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response1).isNotNull(); + + ChatResponse response2 = chatClient.prompt("What is my name?") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response2).isNotNull(); + assertThat(response2.getResults()).hasSize(1); + assertThat(response2.getResults().get(0).getOutput().getText()).contains("James Bond"); + } + + @Test + void chatMemoryWithMessageWindowSize() { + ChatClient chatClient = ChatClient.builder(this.chatModel) + .defaultMemory(MessageWindowChatMemory.builder().maxMessages(3).build()) + .build(); + + String conversationId = "007"; + + ChatResponse response1 = chatClient.prompt("The cat is on the table") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response1).isNotNull(); + + ChatResponse response2 = chatClient.prompt("My name is Bond. James Bond.") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response2).isNotNull(); + + ChatResponse response3 = chatClient.prompt("What is my name?") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response3).isNotNull(); + + ChatResponse response4 = chatClient.prompt("Where is the cat?") + .conversationId(conversationId) + .call() + .chatResponse(); + + assertThat(response4).isNotNull(); + assertThat(response2.getResults().get(0).getOutput().getText()).doesNotContainIgnoringCase("table"); + } + record ActorsFilms(String actor, List movies) { } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 14ca4890d9f..b9da1a89fd8 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -27,6 +27,7 @@ import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -247,6 +248,10 @@ interface ChatClientRequestSpec { ChatClientRequestSpec user(Consumer consumer); + ChatClientRequestSpec memory(ChatMemory chatMemory); + + ChatClientRequestSpec conversationId(String conversationId); + CallResponseSpec call(); StreamResponseSpec stream(); @@ -294,6 +299,8 @@ interface Builder { Builder defaultToolContext(Map toolContext); + Builder defaultMemory(ChatMemory chatMemory); + Builder clone(); ChatClient build(); diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java index eae44a51728..8fba5514899 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/ChatClientResponse.java @@ -51,7 +51,7 @@ public static class Builder { private Builder() { } - public Builder chatResponse(ChatResponse chatResponse) { + public Builder chatResponse(@Nullable ChatResponse chatResponse) { this.chatResponse = chatResponse; return this; } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 3a4655ed95e..c7decba6ff4 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -44,6 +44,7 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation; import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -131,7 +132,7 @@ private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inpu public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest, ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) { - return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(), + return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), null, null, advisedRequest.userText(), advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(), advisedRequest.toolCallbacks(), advisedRequest.messages(), advisedRequest.toolNames(), advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(), @@ -660,10 +661,13 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe private final Map advisorParams = new HashMap<>(); - private final DefaultAroundAdvisorChain.Builder aroundAdvisorChainBuilder; - private final Map toolContext = new HashMap<>(); + private String conversationId; + + @Nullable + private ChatMemory chatMemory; + @Nullable private String userText; @@ -675,16 +679,17 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe /* copy constructor */ DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks, - ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, - ccr.observationRegistry, ccr.observationConvention, ccr.toolContext); + this(ccr.chatModel, ccr.chatMemory, ccr.conversationId, ccr.userText, ccr.userParams, ccr.systemText, + ccr.systemParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, + ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention, + ccr.toolContext); } - public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText, - Map userParams, @Nullable String systemText, Map systemParams, - List toolCallbacks, List messages, List toolNames, List media, - @Nullable ChatOptions chatOptions, List advisors, Map advisorParams, - ObservationRegistry observationRegistry, + public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable ChatMemory chatMemory, + @Nullable String conversationId, @Nullable String userText, Map userParams, + @Nullable String systemText, Map systemParams, List toolCallbacks, + List messages, List toolNames, List media, @Nullable ChatOptions chatOptions, + List advisors, Map advisorParams, ObservationRegistry observationRegistry, @Nullable ChatClientObservationConvention observationConvention, Map toolContext) { Assert.notNull(chatModel, "chatModel cannot be null"); @@ -703,6 +708,10 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe this.chatOptions = chatOptions != null ? chatOptions.copy() : (chatModel.getDefaultOptions() != null) ? chatModel.getDefaultOptions().copy() : null; + this.chatMemory = chatMemory; + this.conversationId = StringUtils.hasText(conversationId) ? conversationId + : ChatMemory.DEFAULT_CONVERSATION_ID; + this.userText = userText; this.userParams.putAll(userParams); this.systemText = systemText; @@ -723,9 +732,6 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe // They play the role of the last advisors in the advisor chain. this.advisors.add(new ChatModelCallAdvisor(chatModel)); this.advisors.add(new ChatModelStreamAdvisor(chatModel)); - - this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry) - .pushAll(this.advisors); } private ObservationRegistry getObservationRegistry() { @@ -787,6 +793,15 @@ public Map getToolContext() { return this.toolContext; } + public String getConversationId() { + return this.conversationId; + } + + @Nullable + public ChatMemory getChatMemory() { + return this.chatMemory; + } + /** * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose * settings are replicated from this {@code ChatClientRequest}. @@ -822,7 +837,6 @@ public ChatClientRequestSpec advisors(Consumer consumer) consumer.accept(advisorSpec); this.advisorParams.putAll(advisorSpec.getParams()); this.advisors.addAll(advisorSpec.getAdvisors()); - this.aroundAdvisorChainBuilder.pushAll(advisorSpec.getAdvisors()); return this; } @@ -830,7 +844,6 @@ public ChatClientRequestSpec advisors(Advisor... advisors) { Assert.notNull(advisors, "advisors cannot be null"); Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(Arrays.asList(advisors)); - this.aroundAdvisorChainBuilder.pushAll(Arrays.asList(advisors)); return this; } @@ -838,7 +851,6 @@ public ChatClientRequestSpec advisors(List advisors) { Assert.notNull(advisors, "advisors cannot be null"); Assert.noNullElements(advisors, "advisors cannot contain null elements"); this.advisors.addAll(advisors); - this.aroundAdvisorChainBuilder.pushAll(advisors); return this; } @@ -982,18 +994,40 @@ public ChatClientRequestSpec user(Consumer consumer) { return this; } + @Override + public ChatClientRequestSpec memory(ChatMemory chatMemory) { + Assert.notNull(chatMemory, "chatMemory cannot be null"); + this.chatMemory = chatMemory; + return this; + } + + @Override + public ChatClientRequestSpec conversationId(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.conversationId = conversationId; + return this; + } + public CallResponseSpec call() { - BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build(); + BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain, observationRegistry, observationConvention); } public StreamResponseSpec stream() { - BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build(); + BaseAdvisorChain advisorChain = buildAdvisorChain(); return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain, observationRegistry, observationConvention); } + private BaseAdvisorChain buildAdvisorChain() { + return DefaultAroundAdvisorChain.builder(this.observationRegistry) + .conversationId(this.conversationId) + .chatMemory(this.chatMemory) + .pushAll(this.advisors) + .build(); + } + } // Prompt diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 02b3e29f681..b2d7a477746 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -30,6 +30,7 @@ import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.observation.ChatClientObservationConvention; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.ChatOptions; @@ -64,8 +65,8 @@ public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observa @Nullable ChatClientObservationConvention customObservationConvention) { Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null"); - this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), null, Map.of(), List.of(), - List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, + this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, null, null, Map.of(), null, Map.of(), + List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of()); } @@ -190,6 +191,13 @@ public Builder defaultToolContext(Map toolContext) { return this; } + @Override + public Builder defaultMemory(ChatMemory chatMemory) { + Assert.notNull(chatMemory, "chatMemory cannot be null"); + this.defaultRequest.memory(chatMemory); + return this; + } + void addMessages(List messages) { this.defaultRequest.messages(messages); } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 72a69f657a3..79e9d5289f6 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.function.Function; +import org.springframework.ai.chat.memory.ChatMemory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -38,7 +39,10 @@ * @author Christian Tzolov * @author Ilayaperumal Gopinathan * @since 1.0.0 + * @deprecated in favour of providing the ChatClient directly with a {@link ChatMemory} + * instance. */ +@Deprecated public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { /** diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java index 7d33112899d..e515af0ed8a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/ChatModelCallAdvisor.java @@ -20,11 +20,16 @@ import org.springframework.ai.chat.client.ChatClientResponse; import org.springframework.ai.chat.client.advisor.api.CallAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.core.Ordered; import org.springframework.util.Assert; +import java.util.List; import java.util.Map; /** @@ -45,7 +50,29 @@ public ChatModelCallAdvisor(ChatModel chatModel) { public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); - ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt()); + ChatMemory chatMemory = chain.getChatMemory(); + + ChatResponse chatResponse; + if (chatMemory == null) { + chatResponse = chatModel.call(chatClientRequest.prompt()); + } + else { + String conversationId = chain.getConversationId(); + chatMemory.add(conversationId, chatClientRequest.prompt().getInstructions()); + Prompt prompt = chatClientRequest.prompt().mutate().messages(chatMemory.get(conversationId)).build(); + chatResponse = chatModel.call(prompt); + if (chatResponse != null) { + List generations = chatResponse.getResults(); + if (generations != null) { + List assistantMessages = generations.stream() + .map(generation -> (Message) generation.getOutput()) + .toList(); + chatMemory.add(conversationId, assistantMessages); + } + } + + } + return ChatClientResponse.builder() .chatResponse(chatResponse) .context(Map.copyOf(chatClientRequest.context())) diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index 6ee2aedb30b..254d5a7bfa2 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -33,6 +33,8 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext; @@ -61,15 +63,23 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain { private final Deque streamAroundAdvisors; + private final String conversationId; + + @Nullable + private final ChatMemory chatMemory; + private final ObservationRegistry observationRegistry; - DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque callAroundAdvisors, + private DefaultAroundAdvisorChain(String conversationId, @Nullable ChatMemory chatMemory, + ObservationRegistry observationRegistry, Deque callAroundAdvisors, Deque streamAroundAdvisors) { - + Assert.hasText(conversationId, "the conversationId must not be null or empty"); Assert.notNull(observationRegistry, "the observationRegistry must be non-null"); Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null"); Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null"); + this.conversationId = conversationId; + this.chatMemory = chatMemory; this.observationRegistry = observationRegistry; this.callAroundAdvisors = callAroundAdvisors; this.streamAroundAdvisors = streamAroundAdvisors; @@ -79,6 +89,17 @@ public static Builder builder(ObservationRegistry observationRegistry) { return new Builder(observationRegistry); } + @Override + public String getConversationId() { + return this.conversationId; + } + + @Override + @Nullable + public ChatMemory getChatMemory() { + return this.chatMemory; + } + @Override public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) { Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null"); @@ -240,12 +261,27 @@ public static class Builder { private final Deque streamAroundAdvisors; + private String conversationId = ChatMemory.DEFAULT_CONVERSATION_ID; + + @Nullable + private ChatMemory chatMemory; + public Builder(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; this.callAroundAdvisors = new ConcurrentLinkedDeque<>(); this.streamAroundAdvisors = new ConcurrentLinkedDeque<>(); } + public Builder conversationId(String conversationId) { + this.conversationId = conversationId; + return this; + } + + public Builder chatMemory(@Nullable ChatMemory chatMemory) { + this.chatMemory = chatMemory; + return this; + } + public Builder push(Advisor aroundAdvisor) { Assert.notNull(aroundAdvisor, "the aroundAdvisor must be non-null"); return this.pushAll(List.of(aroundAdvisor)); @@ -293,8 +329,8 @@ private void reOrder() { } public DefaultAroundAdvisorChain build() { - return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAroundAdvisors, - this.streamAroundAdvisors); + return new DefaultAroundAdvisorChain(this.conversationId, this.chatMemory, this.observationRegistry, + this.callAroundAdvisors, this.streamAroundAdvisors); } } diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index cd1c53cb301..03eaebceac3 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.MessageAggregator; @@ -36,7 +37,10 @@ * * @author Christian Tzolov * @since 1.0.0 + * @deprecated in favor of providing the ChatClient directly with a + * {@link MessageWindowChatMemory} instance. */ +@Deprecated public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor { public MessageChatMemoryAdvisor(ChatMemory chatMemory) { diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisorChain.java new file mode 100644 index 00000000000..0f108713afe --- /dev/null +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisorChain.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client.advisor.api; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.lang.Nullable; + +/** + * Defines the context for executing a chain of advisors as part of processing a chat + * request. + */ +public interface AdvisorChain { + + default String getConversationId() { + return ChatMemory.DEFAULT_CONVERSATION_ID; + } + + @Nullable + default ChatMemory getChatMemory() { + return null; + } + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java index 8f4a62825ef..e768f45e1cc 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/CallAroundAdvisorChain.java @@ -28,7 +28,7 @@ * @deprecated in favor of {@link CallAdvisorChain} */ @Deprecated -public interface CallAroundAdvisorChain { +public interface CallAroundAdvisorChain extends AdvisorChain { /** * Invokes the next Around Advisor in the CallAroundAdvisorChain with the given diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java index 7ab9631785a..abb4a62e60f 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/api/StreamAroundAdvisorChain.java @@ -29,7 +29,7 @@ * @deprecated in favor of {@link StreamAdvisorChain} */ @Deprecated -public interface StreamAroundAdvisorChain { +public interface StreamAroundAdvisorChain extends AdvisorChain { /** * This method delegates the call to the next StreamAroundAdvisor in the chain and is diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java index 73318fc8c4b..3056bb5de18 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemory.java @@ -34,7 +34,10 @@ * @see ChatMemory * @author Christian Tzolov * @since 1.0.0 M1 + * @deprecated in favor of {@link MessageWindowChatMemory}, which internally uses + * {@link InMemoryChatMemoryRepository}. */ +@Deprecated public class InMemoryChatMemory implements ChatMemory { Map> conversationHistory = new ConcurrentHashMap<>(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java index 4bb321f85d9..a42287aaf02 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java @@ -95,4 +95,11 @@ void whenSystemCharsetIsNullThenThrows() { .hasMessage("charset cannot be null"); } + @Test + void whenChatMemoryIsNullThenThrows() { + DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class)); + assertThatThrownBy(() -> builder.defaultMemory(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatMemory cannot be null"); + } + } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index 3a8d269fd77..69b0298b4a6 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -20,16 +20,15 @@ import java.net.URI; import java.net.URL; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Consumer; import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; @@ -1300,15 +1299,15 @@ void whenChatResponseContentIsNullThenReturnFlux() { void buildChatClientRequestSpec() { ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec spec = new DefaultChatClient.DefaultChatClientRequestSpec( - chatModel, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), - Map.of(), ObservationRegistry.NOOP, null, Map.of()); + chatModel, null, null, null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, + List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of()); assertThat(spec).isNotNull(); } @Test void whenChatModelIsNullThenThrow() { - assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, Map.of(), null, - Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(null, null, null, null, Map.of(), + null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("chatModel cannot be null"); @@ -1316,9 +1315,9 @@ void whenChatModelIsNullThenThrow() { @Test void whenObservationRegistryIsNullThenThrow() { - assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, - Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), null, - null, Map.of())) + assertThatThrownBy(() -> new DefaultChatClient.DefaultChatClientRequestSpec(mock(ChatModel.class), null, null, + null, Map.of(), null, Map.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(), + null, null, Map.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("observationRegistry cannot be null"); } @@ -1910,6 +1909,50 @@ void whenUserConsumerWithoutUserTextThenReturn() { assertThat(defaultSpec.getMedia()).hasSize(1); } + @Test + void whenConversationIdThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + String conversationId = UUID.randomUUID().toString(); + spec = spec.conversationId(conversationId); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getConversationId()).isEqualTo(conversationId); + } + + @Test + void whenConversationIdIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.conversationId(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("conversationId cannot be null or empty"); + } + + @Test + void whenConversationIdIsEmptyThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.conversationId("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("conversationId cannot be null or empty"); + } + + @Test + void whenChatMemoryThenReturn() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + spec = spec.memory(chatMemory); + DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; + assertThat(defaultSpec.getChatMemory()).isEqualTo(chatMemory); + } + + @Test + void whenChatMemoryIsNullThenThrow() { + ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); + ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); + assertThatThrownBy(() -> spec.memory(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatMemory cannot be null"); + } + record Person(String name) { } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 86e963deb0d..c2e825e0f0a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -94,6 +94,7 @@ * xref:api/retrieval-augmented-generation.adoc[Retrieval Augmented Generation (RAG)] ** xref:api/etl-pipeline.adoc[] * xref:api/structured-output-converter.adoc[Structured Output] +* xref:api/memory.adoc[Memory] * xref:api/tools.adoc[Tool Calling] ** xref:api/tools-migration.adoc[Migrating to ToolCallback API] * xref:api/mcp/mcp-overview.adoc[Model Context Protocol (MCP)] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/memory.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/memory.adoc new file mode 100644 index 00000000000..22a48121f69 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/memory.adoc @@ -0,0 +1,152 @@ +[[Memory]] += Memory + +Large language models (LLMs) are stateless, meaning they do not retain information about previous interactions. This can be a limitation when you want to maintain context or state across multiple interactions. To address this, Spring AI provides a `ChatMemory` abstraction that allows you to store and retrieve information across multiple interactions with the LLM. + +WARNING: In previous versions of Spring AI, the way to provide memory to a Chat Client was by using one of the available advisors (e.g. `MessageChatMemoryAdvisor`, `PromptChatMemoryAdvisor`, or `VectorStoreChatMemoryAdvisor`). This approach has been deprecated in favor of a more flexible and extensible memory system. The new system allows you to provide memory directly to the `ChatClient`, making it easier to manage and customize memory behavior. When adopting the new system, make sure to remove any old memory advisors from your configuration to avoid unintended behavior. + +== Usage + +=== Providing Memory to Chat Client + +The `ChatClient` can be configured with a memory implementation to maintain conversation context across multiple interactions. Here's an example of how to set up a `ChatClient` with memory: + +[source,java] +---- +ChatClient chatClient = ChatClient.builder(chatModel) + .defaultMemory(MessageWindowChatMemory.builder().build()) + .build(); +---- + +When sending a prompt to the model, you can specify a `conversationId` to associate the message with a specific conversation. This allows the model to remember previous interactions within that conversation. All messages with the same `conversationId` are stored together in the memory. + +[source,java] +---- +String conversationId = "007"; + +// First interaction +ChatResponse response1 = chatClient.prompt("My name is Bond. James Bond.") + .conversationId(conversationId) + .call() + .chatResponse(); + +// Second interaction - the model remembers the name from the first interaction +ChatResponse response2 = chatClient.prompt("What is my name?") + .conversationId(conversationId) + .call() + .chatResponse(); + +// The response will contain "James Bond" +---- + +=== Choosing a Storage for the Memory + +A `ChatMemory` implementation typically uses a `ChatMemoryRepository` to store messages. The repository is responsible for persisting the messages and providing retrieval capabilities. You can choose from various repository implementations based on your storage needs, such as in-memory, JDBC, Cassandra, or Neo4j. + +Let's consider `MessageWindowChatMemory` as an example. By default, it uses an in-memory repository (`InMemoryChatMemoryRepository`), which is suitable for simple use cases. However, if you need to persist messages across application restarts or share memory between different instances, you can configure it to use a different repository. + +[source,java] +---- +// Create a custom repository +ChatMemoryRepository repository = new InMemoryChatMemoryRepository(); + +// Configure the memory with the custom repository +MessageWindowChatMemory memory = MessageWindowChatMemory.builder() + .chatMemoryRepository(repository) + .maxMessages(10) // Limit the number of messages stored + .build(); + +// Create a ChatClient with the configured memory +ChatClient chatClient = ChatClient.builder(chatModel) + .defaultMemory(memory) + .build(); +---- + +=== Managing Memory Manually with ChatModel + +If you're working directly with a `ChatModel` instead of a `ChatClient`, you can manage the memory manually: + +[source,java] +---- +// Create a memory instance +ChatMemory memory = MessageWindowChatMemory.builder().build(); +String conversationId = "007"; + +// First interaction +UserMessage userMessage1 = new UserMessage("My name is James Bond"); +memory.add(conversationId, userMessage1); +ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId))); +memory.add(conversationId, response1.getResult().getOutput()); + +// Second interaction +UserMessage userMessage2 = new UserMessage("What is my name?"); +memory.add(conversationId, userMessage2); +ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId))); +memory.add(conversationId, response2.getResult().getOutput()); + +// The response will contain "James Bond" +---- + +== Memory Types + +The `ChatMemory` abstraction allows you to implement various types of memory to suit different use cases. The choice of memory type can significantly impact the performance and behavior of your application. This section describes the built-in memory types provided by Spring AI and their characteristics. + +=== Message Window Chat Memory + +`MessageWindowChatMemory` maintains a window of messages up to a specified maximum size. When the number of messages exceeds the maximum, older messages are removed while preserving system messages. The default window size is 200 messages. + +[source,java] +---- +// Create a memory with a window size of 10 messages +MessageWindowChatMemory memory = MessageWindowChatMemory.builder() + .maxMessages(10) + .build(); +---- + +== Storage + +Spring AI offers the `ChatMemoryRepository` abstraction for storing chat memory. This section describes the built-in repositories provided by Spring AI and how to use them, but you can also implement your own repository if needed. + +=== In-Memory Repository + +`InMemoryChatMemoryRepository` stores messages in memory using a `ConcurrentHashMap`. + +[source,java] +---- +ChatMemoryRepository repository = new InMemoryChatMemoryRepository(); +---- + +=== JDBC Repository + +`JdbcChatMemoryRepository` is a built-in implementation that uses JDBC to store messages in a relational database. It is suitable for applications that require persistent storage of chat memory. + +WARNING: When you use the JDBC repository, make sure to size the chat memory appropriately. For example, the `MessageWindowChatMemory` implementation has a default size of 200 messages, which might be too much to retrieve and store in a single database transaction. You can adjust the size using the `maxMessages` property. + +[source,java] +---- +JdbcChatMemoryConfig config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); +JdbcChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.create(config); + +MessageWindowChatMemory memory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(20) // Limit the number of messages stored + .build(); +---- + +Spring AI provides auto-configuration for `JdbcChatMemoryRepository`. + +[source,java] +---- +@RestController +class AiController { + + private final ChatMemory chatMemory; + + public AiController(JdbcChatMemoryRepository chatMemoryRepository) { + this.chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(chatMemoryRepository) + .maxMessages(20) + .build(); + } +} +---- diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java new file mode 100644 index 00000000000..299c78d1aed --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.List; + +/** + * The contract for storing and managing the history of chat conversations. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface ChatMemory { + + String DEFAULT_CONVERSATION_ID = "default"; + + /** + * Save the specified message in the chat memory for the specified conversation. + */ + default void add(String conversationId, Message message) { + this.add(conversationId, Collections.singletonList(message)); + } + + /** + * Save the specified messages in the chat memory for the specified conversation. + */ + void add(String conversationId, List messages); + + /** + * Get the messages in the chat memory for the specified conversation. + */ + default List get(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return get(conversationId, Integer.MAX_VALUE); + } + + /** + * @deprecated in favor of {@link MessageWindowChatMemory}. + */ + @Deprecated + List get(String conversationId, int lastN); + + /** + * Clear the chat memory for the specified conversation. + */ + void clear(String conversationId); + +} diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java similarity index 52% rename from spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java rename to spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java index 7003457df16..4a7229c6818 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/memory/ChatMemory.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/ChatMemoryRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,30 +16,22 @@ package org.springframework.ai.chat.memory; -import java.util.List; - import org.springframework.ai.chat.messages.Message; +import java.util.List; + /** - * The ChatMemory interface represents a storage for chat conversation history. It - * provides methods to add messages to a conversation, retrieve messages from a - * conversation, and clear the conversation history. + * A repository for storing and retrieving chat messages. * - * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ -public interface ChatMemory { - - // TODO: consider a non-blocking interface for streaming usages - - default void add(String conversationId, Message message) { - this.add(conversationId, List.of(message)); - } +public interface ChatMemoryRepository { - void add(String conversationId, List messages); + List findById(String conversationId); - List get(String conversationId, int lastN); + void save(String conversationId, List messages); - void clear(String conversationId); + void deleteById(String conversationId); } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicy.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicy.java new file mode 100644 index 00000000000..2c217d601a4 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicy.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.List; + +/** + * A policy that adds new messages to the existing history messages and ensures that the + * total number of messages does not exceed the specified limit. + *

+ * Messages of type {@link SystemMessage} are treated specially: if a new + * {@link SystemMessage} is added, all previous {@link SystemMessage} instances are + * removed from the history. Also, if the total number of messages exceeds the limit, the + * {@link SystemMessage} messages are preserved while removing other types of messages. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class DefaultMessageWindowProcessingPolicy implements MessageWindowProcessingPolicy { + + @Override + public List process(List historyMessages, List newMessages, int limit) { + Assert.notNull(historyMessages, "historyMessages cannot be null"); + Assert.noNullElements(historyMessages, "historyMessages cannot contain null elements"); + Assert.notNull(newMessages, "newMessages cannot be null"); + Assert.noNullElements(newMessages, "newMessages cannot contain null elements"); + Assert.isTrue(limit > 0, "limit must be greater than 0"); + + List processedMessages = new ArrayList<>(historyMessages); + + for (Message newMessage : newMessages) { + if (newMessage instanceof SystemMessage systemMessage && !processedMessages.contains(systemMessage)) { + // If a new SystemMessage is added, remove all previous SystemMessages + processedMessages.removeIf(m -> m instanceof SystemMessage); + break; + } + } + + processedMessages.addAll(new ArrayList<>(newMessages)); + + if (processedMessages.size() <= limit) { + return processedMessages; + } + + int messagesToRemove = processedMessages.size() - limit; + int index = 0; + + while (messagesToRemove > 0 && index < processedMessages.size()) { + if (!(processedMessages.get(index) instanceof SystemMessage)) { + processedMessages.remove(index); + messagesToRemove--; + } + else { + index++; + } + } + + return processedMessages; + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java new file mode 100644 index 00000000000..ace49b462c5 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * An in-memory implementation of {@link ChatMemoryRepository}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class InMemoryChatMemoryRepository implements ChatMemoryRepository { + + Map> chatMemoryStore = new ConcurrentHashMap<>(); + + @Override + public List findById(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + List messages = this.chatMemoryStore.get(conversationId); + return messages != null ? new ArrayList<>(messages) : List.of(); + } + + @Override + public void save(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + + this.chatMemoryStore.putIfAbsent(conversationId, new ArrayList<>()); + this.chatMemoryStore.get(conversationId).addAll(messages); + } + + @Override + public void deleteById(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.chatMemoryStore.remove(conversationId); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java new file mode 100644 index 00000000000..80e2082408c --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java @@ -0,0 +1,119 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.util.Assert; + +import java.util.List; + +/** + * A chat memory implementation that maintains a message window of a specified size. When + * the number of messages exceeds the maximum size, older messages are removed while + * preserving SystemMessages. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class MessageWindowChatMemory implements ChatMemory { + + private static final int DEFAULT_MAX_MESSAGES = 200; + + private static final ChatMemoryRepository DEFAULT_CHAT_MEMORY_REPOSITORY = new InMemoryChatMemoryRepository(); + + private static final MessageWindowProcessingPolicy DEFAULT_MESSAGE_WINDOW_EVICTION_POLICY = new DefaultMessageWindowProcessingPolicy(); + + private final ChatMemoryRepository chatMemoryRepository; + + private final MessageWindowProcessingPolicy messageWindowProcessingPolicy; + + private final int maxMessages; + + private MessageWindowChatMemory(ChatMemoryRepository chatMemoryRepository, + MessageWindowProcessingPolicy messageWindowProcessingPolicy, int maxMessages) { + this.chatMemoryRepository = chatMemoryRepository; + this.messageWindowProcessingPolicy = messageWindowProcessingPolicy; + this.maxMessages = maxMessages; + } + + @Override + public void add(String conversationId, List messages) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + Assert.notNull(messages, "messages cannot be null"); + Assert.noNullElements(messages, "messages cannot contain null elements"); + + List historyMessages = this.chatMemoryRepository.findById(conversationId); + List processedMessages = this.messageWindowProcessingPolicy.process(historyMessages, messages, + this.maxMessages); + this.chatMemoryRepository.save(conversationId, processedMessages); + } + + @Override + public List get(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + return this.chatMemoryRepository.findById(conversationId); + } + + @Override + @Deprecated // in favor of get(conversationId) + public List get(String conversationId, int lastN) { + return get(conversationId); + } + + @Override + public void clear(String conversationId) { + Assert.hasText(conversationId, "conversationId cannot be null or empty"); + this.chatMemoryRepository.deleteById(conversationId); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ChatMemoryRepository chatMemoryRepository = DEFAULT_CHAT_MEMORY_REPOSITORY; + + private int maxMessages = DEFAULT_MAX_MESSAGES; + + private MessageWindowProcessingPolicy messageWindowProcessingPolicy = DEFAULT_MESSAGE_WINDOW_EVICTION_POLICY; + + private Builder() { + } + + public Builder chatMemoryRepository(ChatMemoryRepository chatMemoryRepository) { + this.chatMemoryRepository = chatMemoryRepository; + return this; + } + + public Builder messageWindowEvictionPolicy(MessageWindowProcessingPolicy messageWindowProcessingPolicy) { + this.messageWindowProcessingPolicy = messageWindowProcessingPolicy; + return this; + } + + public Builder maxMessages(int maxMessages) { + this.maxMessages = maxMessages; + return this; + } + + public MessageWindowChatMemory build() { + return new MessageWindowChatMemory(chatMemoryRepository, messageWindowProcessingPolicy, maxMessages); + } + + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowProcessingPolicy.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowProcessingPolicy.java new file mode 100644 index 00000000000..0c73a54ac80 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowProcessingPolicy.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.Message; + +import java.util.List; + +/** + * A policy for processing a message window in a chat memory system. It defines a strategy + * for handling the addition of new messages to the existing message history, ensuring + * that the total number of messages does not exceed a specified limit. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface MessageWindowProcessingPolicy { + + /** + * Processes the message window by adding new messages to the existing history + * messages and ensuring that the total number of messages does not exceed the + * specified limit. + */ + List process(List historyMessages, List newMessages, int limit); + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java new file mode 100644 index 00000000000..2dd55ff2556 --- /dev/null +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/memory/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.chat.memory; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicyTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicyTests.java new file mode 100644 index 00000000000..a41566fef82 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/DefaultMessageWindowProcessingPolicyTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link DefaultMessageWindowProcessingPolicy}. + * + * @author Thomas Vitale + */ +public class DefaultMessageWindowProcessingPolicyTests { + + private final MessageWindowProcessingPolicy processingPolicy = new DefaultMessageWindowProcessingPolicy(); + + @Test + void noEvictionWhenMessagesWithinLimit() { + List historyMessages = new ArrayList<>( + List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"))); + List newMessages = new ArrayList<>(List.of(new UserMessage("How are you?"))); + int limit = 3; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new UserMessage("Hello"), new AssistantMessage("Hi there"), + new UserMessage("How are you?")); + } + + @Test + void evictionWhenMessagesExceedLimit() { + List historyMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + int limit = 2; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).hasSize(2); + assertThat(result).containsExactly(new UserMessage("Message 2"), new AssistantMessage("Response 2")); + } + + @Test + void systemMessageIsPreserved() { + List historyMessages = new ArrayList<>(List.of(new SystemMessage("System instruction"), + new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + int limit = 3; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new SystemMessage("System instruction"), new UserMessage("Message 2"), + new AssistantMessage("Response 2")); + } + + @Test + void multipleSystemMessagesArePreserved() { + List historyMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"), + new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + int limit = 3; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).hasSize(3); + assertThat(result).containsExactly(new SystemMessage("System instruction 1"), + new SystemMessage("System instruction 2"), new AssistantMessage("Response 2")); + } + + @Test + void emptyMessageList() { + List historyMessages = new ArrayList<>(); + List newMessages = new ArrayList<>(); + int limit = 5; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).isEmpty(); + } + + @Test + void zeroLimitNotAllowed() { + List historyMessages = new ArrayList<>(List.of(new UserMessage("Message 1"))); + List newMessages = new ArrayList<>(List.of(new AssistantMessage("Response 1"))); + + assertThatThrownBy(() -> processingPolicy.process(historyMessages, newMessages, 0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("limit must be greater than 0"); + } + + @Test + void negativeLimitNotAllowed() { + List historyMessages = new ArrayList<>(List.of(new UserMessage("Message 1"))); + List newMessages = new ArrayList<>(List.of(new AssistantMessage("Response 1"))); + + assertThatThrownBy(() -> processingPolicy.process(historyMessages, newMessages, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("limit must be greater than 0"); + } + + @Test + void oldSystemMessagesAreRemovedEvenWithCountLessThanLimit() { + List historyMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"))); + List newMessages = new ArrayList<>(List.of(new SystemMessage("System instruction 3"))); + int limit = 2; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + // Old system messages are moved if a new one is provided, even if there's room in + // the window. + assertThat(result).hasSize(1); + assertThat(result).containsExactly(new SystemMessage("System instruction 3")); + } + + @Test + void mixedMessagesWithLimitEqualToSystemMessageCount() { + List historyMessages = new ArrayList<>( + List.of(new SystemMessage("System instruction 1"), new SystemMessage("System instruction 2"))); + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + int limit = 2; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(result).hasSize(2); + assertThat(result).containsExactly(new SystemMessage("System instruction 1"), + new SystemMessage("System instruction 2")); + } + + @Test + void originalListIsNotModified() { + List historyMessages = new ArrayList<>( + List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"))); + List newMessages = new ArrayList<>( + List.of(new UserMessage("Message 2"), new AssistantMessage("Response 2"))); + List originalHistoryMessages = new ArrayList<>(historyMessages); + List originalNewMessages = new ArrayList<>(newMessages); + int limit = 4; + + List result = processingPolicy.process(historyMessages, newMessages, limit); + + assertThat(historyMessages).isEqualTo(originalHistoryMessages); + assertThat(newMessages).isEqualTo(originalNewMessages); + assertThat(result).isNotSameAs(historyMessages); + assertThat(result).isNotSameAs(newMessages); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java new file mode 100644 index 00000000000..0343d625d44 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepositoryTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link InMemoryChatMemoryRepository}. + * + * @author Thomas Vitale + */ +public class InMemoryChatMemoryRepositoryTests { + + private final InMemoryChatMemoryRepository chatMemoryRepository = new InMemoryChatMemoryRepository(); + + @Test + void saveAndFindMultipleMessagesInConversation() { + String conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("I, Robot"), new UserMessage("Hello")); + + chatMemoryRepository.save(conversationId, messages); + + assertThat(chatMemoryRepository.findById(conversationId)).containsAll(messages); + + chatMemoryRepository.deleteById(conversationId); + + assertThat(chatMemoryRepository.findById(conversationId)).isEmpty(); + } + + @Test + void saveAndFindSingleMessageInConversation() { + String conversationId = UUID.randomUUID().toString(); + Message message = new UserMessage("Hello"); + List messages = List.of(message); + + chatMemoryRepository.save(conversationId, messages); + + assertThat(chatMemoryRepository.findById(conversationId)).contains(message); + + chatMemoryRepository.deleteById(conversationId); + + assertThat(chatMemoryRepository.findById(conversationId)).isEmpty(); + } + + @Test + void findNonExistingConversation() { + String conversationId = UUID.randomUUID().toString(); + + assertThat(chatMemoryRepository.findById(conversationId)).isEmpty(); + } + + @Test + void saveMultipleMessagesForSameConversation() { + String conversationId = UUID.randomUUID().toString(); + List firstMessages = List.of(new UserMessage("Hello")); + List secondMessages = List.of(new AssistantMessage("Hi there")); + + chatMemoryRepository.save(conversationId, firstMessages); + chatMemoryRepository.save(conversationId, secondMessages); + + List allMessages = new ArrayList<>(); + allMessages.addAll(firstMessages); + allMessages.addAll(secondMessages); + + assertThat(chatMemoryRepository.findById(conversationId)).containsExactlyElementsOf(allMessages); + } + + @Test + void nullConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemoryRepository.save(null, List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.findById(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.deleteById(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void emptyConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemoryRepository.save("", List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.findById("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemoryRepository.deleteById("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void nullMessagesNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemoryRepository.save(conversationId, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot be null"); + } + + @Test + void messagesWithNullElementsNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + List messagesWithNull = new ArrayList<>(); + messagesWithNull.add(null); + + assertThatThrownBy(() -> chatMemoryRepository.save(conversationId, messagesWithNull)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java new file mode 100644 index 00000000000..af0fad95656 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/memory/MessageWindowChatMemoryTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MessageWindowChatMemory}. + * + * @author Thomas Vitale + */ +public class MessageWindowChatMemoryTests { + + private final MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + + @Test + void handleMultipleMessagesInConversation() { + String conversationId = UUID.randomUUID().toString(); + List messages = List.of(new AssistantMessage("I, Robot"), new UserMessage("Hello")); + + chatMemory.add(conversationId, messages); + + assertThat(chatMemory.get(conversationId)).containsAll(messages); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void handleSingleMessageInConversation() { + String conversationId = UUID.randomUUID().toString(); + Message message = new UserMessage("Hello"); + + chatMemory.add(conversationId, message); + + assertThat(chatMemory.get(conversationId)).contains(message); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId)).isEmpty(); + } + + @Test + void nullConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemory.add(null, List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.get(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.clear(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void emptyConversationIdNotAllowed() { + assertThatThrownBy(() -> chatMemory.add("", List.of(new UserMessage("Hello")))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.add(null, new UserMessage("Hello"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.get("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + + assertThatThrownBy(() -> chatMemory.clear("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("conversationId cannot be null or empty"); + } + + @Test + void nullMessagesNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemory.add(conversationId, (List) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot be null"); + } + + @Test + void nullMessageNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + assertThatThrownBy(() -> chatMemory.add(conversationId, (Message) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + + @Test + void messagesWithNullElementsNotAllowed() { + String conversationId = UUID.randomUUID().toString(); + List messagesWithNull = new ArrayList<>(); + messagesWithNull.add(null); + + assertThatThrownBy(() -> chatMemory.add(conversationId, messagesWithNull)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("messages cannot contain null elements"); + } + + @Test + void customMaxMessages() { + String conversationId = UUID.randomUUID().toString(); + int customMaxMessages = 2; + + MessageWindowChatMemory customChatMemory = MessageWindowChatMemory.builder() + .maxMessages(customMaxMessages) + .build(); + + List messages = List.of(new UserMessage("Message 1"), new AssistantMessage("Response 1"), + new UserMessage("Message 2"), new AssistantMessage("Response 2"), new UserMessage("Message 3")); + + customChatMemory.add(conversationId, messages); + List result = customChatMemory.get(conversationId); + + assertThat(result).hasSize(2); + } + +}