diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..19e4f0e7bb7 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/pom.xml @@ -0,0 +1,71 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../../../../pom.xml + + spring-ai-autoconfigure-model-chat-memory-jdbc + jar + Spring AI JDBC Chat Memory Auto Configuration + Spring JDBC AI Chat Memory Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.ai + spring-ai-model-chat-memory-jdbc + ${project.parent.version} + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.postgresql + postgresql + ${postgresql.version} + test + + + + org.testcontainers + junit-jupiter + test + + + + org.testcontainers + postgresql + test + + + + diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java new file mode 100644 index 00000000000..bc811c3ded6 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfiguration.java @@ -0,0 +1,65 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import javax.sql.DataSource; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemoryConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) +@ConditionalOnClass({ JdbcChatMemory.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(JdbcChatMemoryProperties.class) +public class JdbcChatMemoryAutoConfiguration { + + private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryAutoConfiguration.class); + + @Bean + @ConditionalOnMissingBean + public JdbcChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + + return JdbcChatMemory.create(config); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(value = "spring.ai.chat.memory.jdbc.initialize-schema", havingValue = "true", + matchIfMissing = true) + public DataSourceScriptDatabaseInitializer jdbcChatMemoryScriptDatabaseInitializer(DataSource dataSource) { + logger.debug("Initializing JdbcChatMemory schema"); + + return new JdbcChatMemoryDataSourceScriptDatabaseInitializer(dataSource); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java new file mode 100644 index 00000000000..2f1927048df --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializer.java @@ -0,0 +1,35 @@ +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import java.util.List; + +import javax.sql.DataSource; + +import org.springframework.boot.jdbc.init.DataSourceScriptDatabaseInitializer; +import org.springframework.boot.jdbc.init.PlatformPlaceholderDatabaseDriverResolver; +import org.springframework.boot.sql.init.DatabaseInitializationMode; +import org.springframework.boot.sql.init.DatabaseInitializationSettings; + +class JdbcChatMemoryDataSourceScriptDatabaseInitializer extends DataSourceScriptDatabaseInitializer { + + private static final String SCHEMA_LOCATION = "classpath:org/springframework/ai/chat/memory/jdbc/schema-@@platform@@.sql"; + + public JdbcChatMemoryDataSourceScriptDatabaseInitializer(DataSource dataSource) { + super(dataSource, getSettings(dataSource)); + } + + static DatabaseInitializationSettings getSettings(DataSource dataSource) { + var settings = new DatabaseInitializationSettings(); + settings.setSchemaLocations(resolveSchemaLocations(dataSource)); + settings.setMode(DatabaseInitializationMode.ALWAYS); + settings.setContinueOnError(true); + + return settings; + } + + static List resolveSchemaLocations(DataSource dataSource) { + var platformResolver = new PlatformPlaceholderDatabaseDriverResolver(); + + return platformResolver.resolveAll(dataSource, SCHEMA_LOCATION); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java new file mode 100644 index 00000000000..1c33ffbb0a5 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryProperties.java @@ -0,0 +1,40 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@ConfigurationProperties(JdbcChatMemoryProperties.CONFIG_PREFIX) +public class JdbcChatMemoryProperties { + + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.jdbc"; + + private boolean initializeSchema = true; + + public boolean isInitializeSchema() { + return this.initializeSchema; + } + + public void setInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..6820c4237de --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1,16 @@ +# +# Copyright 2024-2025 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +org.springframework.ai.model.chat.memory.jdbc.autoconfigure.JdbcChatMemoryAutoConfiguration diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java new file mode 100644 index 00000000000..6f9573d3eb0 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import java.util.List; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import org.springframework.ai.chat.memory.jdbc.JdbcChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryAutoConfigurationIT { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME) + .withDatabaseName("chat_memory_auto_configuration_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", postgresContainer.getUsername()), + String.format("spring.datasource.password=%s", postgresContainer.getPassword())); + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { + assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue(); + }); + } + + @Test + void jdbcChatMemoryScriptDatabaseInitializer_shouldNotBeLoaded() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=false").run(context -> { + assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isFalse(); + }); + } + + @Test + void addGetAndClear_shouldAllExecute() { + this.contextRunner.withPropertyValues("spring.ai.chat.memory.jdbc.initialize-schema=true").run(context -> { + var chatMemory = context.getBean(JdbcChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java new file mode 100644 index 00000000000..bcb1a9daacc --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryDataSourceScriptDatabaseInitializerTests.java @@ -0,0 +1,51 @@ +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryDataSourceScriptDatabaseInitializerTests { + + static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("postgres:17"); + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(DEFAULT_IMAGE_NAME) + .withDatabaseName("chat_memory_initializer_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues(String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", postgresContainer.getUsername()), + String.format("spring.datasource.password=%s", postgresContainer.getPassword())); + + @Test + void getSettings_shouldHaveSchemaLocations() { + this.contextRunner.run(context -> { + var dataSource = context.getBean(DataSource.class); + var settings = JdbcChatMemoryDataSourceScriptDatabaseInitializer.getSettings(dataSource); + + assertThat(settings.getSchemaLocations()) + .containsOnly("classpath:org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"); + }); + } + +} diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPropertiesTests.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPropertiesTests.java new file mode 100644 index 00000000000..196176149ff --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryPropertiesTests.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.chat.memory.jdbc.autoconfigure; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryPropertiesTests { + + @Test + void defaultValues() { + var props = new JdbcChatMemoryProperties(); + + assertThat(props.isInitializeSchema()).isTrue(); + } + + @Test + void customValues() { + var props = new JdbcChatMemoryProperties(); + props.setInitializeSchema(false); + + assertThat(props.isInitializeSchema()).isFalse(); + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/README.md b/memory/spring-ai-model-chat-memory-jdbc/README.md new file mode 100644 index 00000000000..8e100ad20a3 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/README.md @@ -0,0 +1 @@ +[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory) diff --git a/memory/spring-ai-model-chat-memory-jdbc/pom.xml b/memory/spring-ai-model-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..3e7adb52392 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/pom.xml @@ -0,0 +1,95 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + spring-ai-model-chat-memory-jdbc + Spring AI JDBC Chat Memory + Spring AI JDBC Chat Memory implementation + + + + org.springframework.ai + spring-ai-client-chat + ${project.version} + + + + org.springframework + spring-jdbc + + + + com.zaxxer + HikariCP + + + + org.postgresql + postgresql + ${postgresql.version} + true + + + + org.mariadb.jdbc + mariadb-java-client + ${mariadb.version} + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + postgresql + test + + + + org.testcontainers + mariadb + test + + + + org.testcontainers + junit-jupiter + test + + + diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java new file mode 100644 index 00000000000..477f7509a19 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.*; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; + +/** + * An implementation of {@link ChatMemory} for JDBC. Creating an instance of + * JdbcChatMemory example: + * JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build()); + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public class JdbcChatMemory implements ChatMemory { + + private static final String QUERY_ADD = """ + INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; + + private static final String QUERY_GET = """ + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; + + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + + private final JdbcTemplate jdbcTemplate; + + public JdbcChatMemory(JdbcChatMemoryConfig config) { + this.jdbcTemplate = config.getJdbcTemplate(); + } + + public static JdbcChatMemory create(JdbcChatMemoryConfig config) { + return new JdbcChatMemory(config); + } + + @Override + public void add(String conversationId, List messages) { + this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public List get(String conversationId, int lastN) { + return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + } + + @Override + public void clear(String conversationId) { + this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + ps.setString(2, message.getText()); + ps.setString(3, message.getMessageType().name()); + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + public Message mapRow(ResultSet rs, int i) throws SQLException { + var content = rs.getString(1); + var type = MessageType.valueOf(rs.getString(2)); + + return switch (type) { + case USER -> new UserMessage(content); + case ASSISTANT -> new AssistantMessage(content); + case SYSTEM -> new SystemMessage(content); + default -> null; + }; + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java new file mode 100644 index 00000000000..5a503aef051 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfig.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.util.Assert; + +/** + * Configuration for {@link JdbcChatMemory}. + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public final class JdbcChatMemoryConfig { + + private final JdbcTemplate jdbcTemplate; + + private JdbcChatMemoryConfig(Builder builder) { + this.jdbcTemplate = builder.jdbcTemplate; + } + + public static Builder builder() { + return new Builder(); + } + + JdbcTemplate getJdbcTemplate() { + return this.jdbcTemplate; + } + + public static final class Builder { + + private JdbcTemplate jdbcTemplate; + + private Builder() { + } + + public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public JdbcChatMemoryConfig build() { + Assert.notNull(this.jdbcTemplate, "jdbc template must not be null"); + + return new JdbcChatMemoryConfig(this); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java new file mode 100644 index 00000000000..eae3206f9a9 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHints.java @@ -0,0 +1,26 @@ +package org.springframework.ai.chat.memory.jdbc.aot.hint; + +import javax.sql.DataSource; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; + +/** + * A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints + * + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.reflection() + .registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.resources() + .registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql") + .registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"); + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..4b6f4a8f5ce --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ +org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql new file mode 100644 index 00000000000..88c0ea11ba0 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content TEXT NOT NULL, + type VARCHAR(10) NOT NULL, + `timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')) +); + +CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, `timestamp` DESC); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql new file mode 100644 index 00000000000..11e60194b60 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content TEXT NOT NULL, + type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')), + "timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, "timestamp" DESC); diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfigTest.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfigTest.java new file mode 100644 index 00000000000..7ae2c4477a1 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryConfigTest.java @@ -0,0 +1,34 @@ +package org.springframework.ai.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; + +import org.springframework.jdbc.core.JdbcTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryConfigTest { + + @Test + void setValues() { + var jdbcTemplate = mock(JdbcTemplate.class); + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + + assertThat(config.getJdbcTemplate()).isEqualTo(jdbcTemplate); + } + + @Test + void setJdbcTemplateToNull_shouldThrow() { + assertThatThrownBy(() -> JdbcChatMemoryConfig.builder().jdbcTemplate(null)); + } + + @Test + void buildWithNullJdbcTemplate_shouldThrow() { + assertThatThrownBy(() -> JdbcChatMemoryConfig.builder().build()); + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java new file mode 100644 index 00000000000..1651bd49e87 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java @@ -0,0 +1,211 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc; + +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.*; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_test") + .withUsername("postgres") + .withPassword("postgres") + .withCopyFileToContainer( + MountableFile.forClasspathResource("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql"), + "/docker-entrypoint-initdb.d/schema.sql"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues(String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword())); + + @BeforeAll + static void beforeAll() { + + } + + @Test + void correctChatMemoryInstance() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + + assertThat(chatMemory).isInstanceOf(JdbcChatMemory.class); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(conversationId, message); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + }); + } + + @Test + void add_shouldInsertMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var results = jdbcTemplate.queryForList(query, conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + }); + } + + @Test + void get_shouldReturnMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + }); + } + + @Test + void clear_shouldDeleteMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId), + new SystemMessage("Message from system - " + conversationId)); + + chatMemory.add(conversationId, messages); + + chatMemory.clear(conversationId); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public ChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build(); + + return JdbcChatMemory.create(config); + } + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public DataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().build(); + } + + } + +} diff --git a/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java new file mode 100644 index 00000000000..1a6c7ff2d83 --- /dev/null +++ b/memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/aot/hint/JdbcChatMemoryRuntimeHintsTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.jdbc.aot.hint; + +import java.io.IOException; +import java.util.Arrays; +import java.util.stream.Stream; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.core.io.support.SpringFactoriesLoader; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHintsTest { + + private final RuntimeHints hints = new RuntimeHints(); + + private final JdbcChatMemoryRuntimeHints jdbcChatMemoryRuntimeHints = new JdbcChatMemoryRuntimeHints(); + + @Test + void aotFactoriesContainsRegistrar() { + var match = SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") + .load(RuntimeHintsRegistrar.class) + .stream() + .anyMatch((registrar) -> registrar instanceof JdbcChatMemoryRuntimeHints); + + assertThat(match).isTrue(); + } + + @ParameterizedTest + @MethodSource("getSchemaFileNames") + void jdbcSchemasHasHints(String schemaFileName) { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + var predicate = RuntimeHintsPredicates.resource() + .forResource("org/springframework/ai/chat/memory/jdbc/" + schemaFileName); + + assertThat(predicate).accepts(this.hints); + } + + @Test + void dataSourceHasHints() { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints); + } + + private static Stream getSchemaFileNames() throws IOException { + var resources = new PathMatchingResourcePatternResolver() + .getResources("classpath*:org/springframework/ai/chat/memory/jdbc/schema-*.sql"); + + return Arrays.stream(resources).map(Resource::getFilename); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 2d36014a719..841c7771e31 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -16,12 +16,7 @@ package org.springframework.ai.anthropic; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import com.fasterxml.jackson.core.type.TypeReference; @@ -30,6 +25,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.*; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -42,10 +38,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.MessageType; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.anthropic.api.AnthropicCacheType; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -256,50 +249,50 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = response.flatMap(chatCompletionResponse -> { - AnthropicApi.Usage usage = chatCompletionResponse.usage(); - Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage(); - Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); - ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); - - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { - - if (chatResponse.hasFinishReasons(Set.of("tool_use"))) { - // FIXME: bounded elastic needs to be used since tool calling - // is currently only synchronous - return Flux.deferContextual((ctx) -> { - // TODO: factor out the tool execution logic with setting context into a utility. - ToolExecutionResult toolExecutionResult; - try { - ToolCallReactiveContextHolder.setContext(ctx); - toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); - } finally { - ToolCallReactiveContextHolder.clearContext(); + AnthropicApi.Usage usage = chatCompletionResponse.usage(); + Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage(); + Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); + ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); + + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { + + if (chatResponse.hasFinishReasons(Set.of("tool_use"))) { + // FIXME: bounded elastic needs to be used since tool calling + // is currently only synchronous + return Flux.deferContextual((ctx) -> { + // TODO: factor out the tool execution logic with setting context into a utility. + ToolExecutionResult toolExecutionResult; + try { + ToolCallReactiveContextHolder.setContext(ctx); + toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); + } finally { + ToolCallReactiveContextHolder.clearContext(); + } + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(chatResponse) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + chatResponse); + } + }).subscribeOn(Schedulers.boundedElastic()); + + } else { + return Mono.empty(); } - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return Flux.just(ChatResponse.builder().from(chatResponse) - .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) - .build()); - } - else { - // Send the tool execution result back to the model. - return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); - } - }).subscribeOn(Schedulers.boundedElastic()); - } else { - return Mono.empty(); - } - - } else { - // If internal tool execution is not required, just return the chat response. - return Mono.just(chatResponse); - } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + } else { + // If internal tool execution is not required, just return the chat response. + return Mono.just(chatResponse); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse); @@ -482,12 +475,35 @@ private Map mergeHttpHeaders(Map runtimeHttpHead ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { + List userMessagesList = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.USER) + .toList(); + Message lastUserMessage = userMessagesList.isEmpty() ? null : userMessagesList.get(userMessagesList.size() - 1); + + List assistantMessageList = prompt.getInstructions() + .stream() + .filter(message -> message.getMessageType() == MessageType.ASSISTANT) + .toList(); + Message lastAssistantMessage = assistantMessageList.isEmpty() ? null + : assistantMessageList.get(assistantMessageList.size() - 1); + List userMessages = prompt.getInstructions() .stream() .filter(message -> message.getMessageType() != MessageType.SYSTEM) .map(message -> { + AbstractMessage abstractMessage = (AbstractMessage) message; if (message.getMessageType() == MessageType.USER) { - List contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + List contents; + boolean isLastItem = message.equals(lastUserMessage); + if (isLastItem && abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contents = new ArrayList<>( + List.of(new ContentBlock(message.getText(), cacheType.cacheControl()))); + } + else { + contents = new ArrayList<>(List.of(new ContentBlock(message.getText()))); + } if (message instanceof UserMessage userMessage) { if (!CollectionUtils.isEmpty(userMessage.getMedia())) { List mediaContent = userMessage.getMedia().stream().map(media -> { @@ -503,8 +519,15 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { else if (message.getMessageType() == MessageType.ASSISTANT) { AssistantMessage assistantMessage = (AssistantMessage) message; List contentBlocks = new ArrayList<>(); + boolean isLastItem = message.equals(lastAssistantMessage); if (StringUtils.hasText(message.getText())) { - contentBlocks.add(new ContentBlock(message.getText())); + if (isLastItem && abstractMessage.getCache() != null) { + AnthropicCacheType cacheType = AnthropicCacheType.valueOf(abstractMessage.getCache()); + contentBlocks.add(new ContentBlock(message.getText(), cacheType.cacheControl())); + } + else { + contentBlocks.add(new ContentBlock(message.getText())); + } } if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { @@ -543,6 +566,7 @@ else if (message.getMessageType() == MessageType.TOOL) { // Add the tool definitions to the request's tools parameter. List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); if (!CollectionUtils.isEmpty(toolDefinitions)) { + var tool = getFunctionTools(toolDefinitions); request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index cf410690216..70f231a088f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -24,6 +24,15 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; @@ -46,14 +55,6 @@ import org.springframework.web.reactive.function.client.WebClient; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** * The Anthropic API client. @@ -94,6 +95,8 @@ public static Builder builder() { private static final String HEADER_ANTHROPIC_BETA = "anthropic-beta"; + public static final String BETA_PROMPT_CACHING = "prompt-caching-2024-07-31"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final String completionsPath; @@ -172,14 +175,14 @@ public ResponseEntity chatCompletionEntity(ChatCompletio // @formatter:off return this.restClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) - .body(chatRequest) - .retrieve() - .toEntity(ChatCompletionResponse.class); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) + .body(chatRequest) + .retrieve() + .toEntity(ChatCompletionResponse.class); // @formatter:on } @@ -213,45 +216,45 @@ public Flux chatCompletionStream(ChatCompletionRequest c // @formatter:off return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) // @formatter:off - .body(Mono.just(chatRequest), ChatCompletionRequest.class) - .retrieve() - .bodyToFlux(String.class) - .takeUntil(SSE_DONE_PREDICATE) - .filter(SSE_DONE_PREDICATE.negate()) - .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) - .filter(event -> event.type() != EventType.PING) - // Detect if the chunk is part of a streaming function call. - .map(event -> { - - logger.debug("Received event: {}", event); - - if (this.streamHelper.isToolUseStart(event)) { - isInsideTool.set(true); - } - return event; - }) - // Group all chunks belonging to the same function call. - .windowUntil(event -> { - if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { - isInsideTool.set(false); - return true; - } - return !isInsideTool.get(); - }) - // Merging the window chunks into a single chunk. - .concatMapIterable(window -> { - Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), - this.streamHelper::mergeToolUseEvents); - return List.of(monoChunk); - }) - .flatMap(mono -> mono) - .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) - .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:off + .body(Mono.just(chatRequest), ChatCompletionRequest.class) + .retrieve() + .bodyToFlux(String.class) + .takeUntil(SSE_DONE_PREDICATE) + .filter(SSE_DONE_PREDICATE.negate()) + .map(content -> ModelOptionsUtils.jsonToObject(content, StreamEvent.class)) + .filter(event -> event.type() != EventType.PING) + // Detect if the chunk is part of a streaming function call. + .map(event -> { + + logger.debug("Received event: {}", event); + + if (this.streamHelper.isToolUseStart(event)) { + isInsideTool.set(true); + } + return event; + }) + // Group all chunks belonging to the same function call. + .windowUntil(event -> { + if (isInsideTool.get() && this.streamHelper.isToolUseFinish(event)) { + isInsideTool.set(false); + return true; + } + return !isInsideTool.get(); + }) + // Merging the window chunks into a single chunk. + .concatMapIterable(window -> { + Mono monoChunk = window.reduce(new ToolUseAggregationEvent(), + this.streamHelper::mergeToolUseEvents); + return List.of(monoChunk); + }) + .flatMap(mono -> mono) + .map(event -> this.streamHelper.eventToChatCompletionResponse(event, chatCompletionReference)) + .filter(chatCompletionResponse -> chatCompletionResponse.type() != null); } private void addDefaultHeadersIfMissing(HttpHeaders headers) { @@ -358,7 +361,7 @@ public enum Role { // @formatter:off /** * The user role. - */ + */ @JsonProperty("user") USER, @@ -514,18 +517,18 @@ public interface StreamEvent { @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( // @formatter:off - @JsonProperty("model") String model, - @JsonProperty("messages") List messages, - @JsonProperty("system") String system, - @JsonProperty("max_tokens") Integer maxTokens, - @JsonProperty("metadata") Metadata metadata, - @JsonProperty("stop_sequences") List stopSequences, - @JsonProperty("stream") Boolean stream, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("top_k") Integer topK, - @JsonProperty("tools") List tools, - @JsonProperty("thinking") ThinkingConfig thinking) { + @JsonProperty("model") String model, + @JsonProperty("messages") List messages, + @JsonProperty("system") String system, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("metadata") Metadata metadata, + @JsonProperty("stop_sequences") List stopSequences, + @JsonProperty("stream") Boolean stream, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("top_k") Integer topK, + @JsonProperty("tools") List tools, + @JsonProperty("thinking") ThinkingConfig thinking) { // @formatter:on public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, @@ -538,17 +541,7 @@ public ChatCompletionRequest(String model, List messages, Stri this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); } - public static ChatCompletionRequestBuilder builder() { - return new ChatCompletionRequestBuilder(); - } - - public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { - return new ChatCompletionRequestBuilder(request); - } - /** - * Metadata about the request. - * * @param userId An external identifier for the user who is associated with the * request. This should be a uuid, hash value, or other opaque identifier. * Anthropic may use this id to help detect abuse. Do not include any identifying @@ -556,7 +549,22 @@ public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { */ @JsonInclude(Include.NON_NULL) public record Metadata(@JsonProperty("user_id") String userId) { + } + + /** + * @param type is the cache type supported by anthropic. Doc + */ + @JsonInclude(Include.NON_NULL) + public record CacheControl(String type) { + } + + public static ChatCompletionRequestBuilder builder() { + return new ChatCompletionRequestBuilder(); + } + public static ChatCompletionRequestBuilder from(ChatCompletionRequest request) { + return new ChatCompletionRequestBuilder(request); } /** @@ -719,8 +727,8 @@ public ChatCompletionRequest build() { @JsonInclude(Include.NON_NULL) public record AnthropicMessage( // @formatter:off - @JsonProperty("content") List content, - @JsonProperty("role") Role role) { + @JsonProperty("content") List content, + @JsonProperty("role") Role role) { // @formatter:on } @@ -744,29 +752,32 @@ public record AnthropicMessage( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlock( // @formatter:off - @JsonProperty("type") Type type, - @JsonProperty("source") Source source, - @JsonProperty("text") String text, + @JsonProperty("type") Type type, + @JsonProperty("source") Source source, + @JsonProperty("text") String text, + + // applicable only for streaming responses. + @JsonProperty("index") Integer index, - // applicable only for streaming responses. - @JsonProperty("index") Integer index, + // tool_use response only + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input, - // tool_use response only - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input, + // tool_result response only + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("content") String content, - // tool_result response only - @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content, + // cache object + @JsonProperty("cache_control") CacheControl cacheControl, - // Thinking only - @JsonProperty("signature") String signature, - @JsonProperty("thinking") String thinking, + // Thinking only + @JsonProperty("signature") String signature, + @JsonProperty("thinking") String thinking, - // Redacted Thinking only - @JsonProperty("data") String data - ) { + // Redacted Thinking only + @JsonProperty("data") String data + ) { // @formatter:on /** @@ -784,7 +795,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -792,7 +803,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null, null); } /** @@ -800,7 +811,11 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null); + } + + public ContentBlock(String text, CacheControl cache) { + this(Type.TEXT, null, text, null, null, null, null, null, null, cache, null, null, null); } // Tool result @@ -811,7 +826,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content, null, null, null); + this(type, null, null, null, null, null, null, toolUseId, content, null, null, null, null); } /** @@ -822,7 +837,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -834,7 +849,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. */ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null, null, null, null); + this(type, null, null, null, id, name, input, null, null, null, null, null, null); } /** @@ -940,10 +955,10 @@ public String getValue() { @JsonInclude(Include.NON_NULL) public record Source( // @formatter:off - @JsonProperty("type") String type, - @JsonProperty("media_type") String mediaType, - @JsonProperty("data") String data, - @JsonProperty("url") String url) { + @JsonProperty("type") String type, + @JsonProperty("media_type") String mediaType, + @JsonProperty("data") String data, + @JsonProperty("url") String url) { // @formatter:on /** @@ -977,9 +992,9 @@ public Source(String url) { @JsonInclude(Include.NON_NULL) public record Tool( // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("input_schema") Map inputSchema) { + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("input_schema") Map inputSchema) { // @formatter:on } @@ -1004,14 +1019,14 @@ public record Tool( @JsonIgnoreProperties(ignoreUnknown = true) public record ChatCompletionResponse( // @formatter:off - @JsonProperty("id") String id, - @JsonProperty("type") String type, - @JsonProperty("role") Role role, - @JsonProperty("content") List content, - @JsonProperty("model") String model, - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence, - @JsonProperty("usage") Usage usage) { + @JsonProperty("id") String id, + @JsonProperty("type") String type, + @JsonProperty("role") Role role, + @JsonProperty("content") List content, + @JsonProperty("model") String model, + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence, + @JsonProperty("usage") Usage usage) { // @formatter:on } @@ -1027,17 +1042,19 @@ public record ChatCompletionResponse( @JsonIgnoreProperties(ignoreUnknown = true) public record Usage( // @formatter:off - @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("input_tokens") Integer inputTokens, + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens, + @JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) { // @formatter:off } - /// ECB STOP + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with * list of aggregated ContentBlockToolUse. - */ + */ public static class ToolUseAggregationEvent implements StreamEvent { private Integer index; @@ -1056,17 +1073,17 @@ public EventType type() { } /** - * Get tool content blocks. - * @return The tool content blocks. - */ + * Get tool content blocks. + * @return The tool content blocks. + */ public List getToolContentBlocks() { return this.toolContentBlocks; } /** - * Check if the event is empty. - * @return True if the event is empty, false otherwise. - */ + * Check if the event is empty. + * @return True if the event is empty, false otherwise. + */ public boolean isEmpty() { return (this.index == null || this.id == null || this.name == null); } @@ -1109,25 +1126,25 @@ public String toString() { } - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// - // MESSAGE START EVENT + // MESSAGE START EVENT /** * Content block start event. * @param type The event type. * @param index The index of the content block. * @param contentBlock The content block body. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStartEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1141,31 +1158,31 @@ public interface ContentBlockBody { } /** - * Tool use content block. - * @param type The content block type. - * @param id The tool use id. - * @param name The tool use name. - * @param input The tool use input. - */ + * Tool use content block. + * @param type The content block type. + * @param id The tool use id. + * @param name The tool use name. + * @param input The tool use input. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockToolUse( - @JsonProperty("type") String type, - @JsonProperty("id") String id, - @JsonProperty("name") String name, - @JsonProperty("input") Map input) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input) implements ContentBlockBody { } /** - * Text content block. - * @param type The content block type. - * @param text The text content. - */ + * Text content block. + * @param type The content block type. + * @param text The text content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockBody { } /** @@ -1175,11 +1192,11 @@ public record ContentBlockText( */ @JsonInclude(Include.NON_NULL) public record ContentBlockThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking, - @JsonProperty("signature") String signature) implements ContentBlockBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking, + @JsonProperty("signature") String signature) implements ContentBlockBody { } - + } // @formatter:on @@ -1196,9 +1213,9 @@ public record ContentBlockThinking( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index, - @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index, + @JsonProperty("delta") ContentBlockDeltaBody delta) implements StreamEvent { @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @@ -1215,24 +1232,24 @@ public interface ContentBlockDeltaBody { * Text content block delta. * @param type The content block type. * @param text The text content. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaText( - @JsonProperty("type") String type, - @JsonProperty("text") String text) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("text") String text) implements ContentBlockDeltaBody { } /** - * JSON content block delta. - * @param type The content block type. - * @param partialJson The partial JSON content. - */ + * JSON content block delta. + * @param type The content block type. + * @param partialJson The partial JSON content. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaJson( - @JsonProperty("type") String type, - @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { } /** @@ -1243,8 +1260,8 @@ public record ContentBlockDeltaJson( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaThinking( - @JsonProperty("type") String type, - @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { } /** @@ -1255,8 +1272,8 @@ public record ContentBlockDeltaThinking( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockDeltaSignature( - @JsonProperty("type") String type, - @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { + @JsonProperty("type") String type, + @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { } } // @formatter:on @@ -1273,8 +1290,8 @@ public record ContentBlockDeltaSignature( @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStopEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("index") Integer index) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("index") Integer index) implements StreamEvent { } // @formatter:on @@ -1287,8 +1304,8 @@ public record ContentBlockStopEvent( @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStartEvent(// @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("message") ChatCompletionResponse message) implements StreamEvent { } // @formatter:on @@ -1303,29 +1320,29 @@ public record MessageStartEvent(// @formatter:off @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("delta") MessageDelta delta, - @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("delta") MessageDelta delta, + @JsonProperty("usage") MessageDeltaUsage usage) implements StreamEvent { /** - * @param stopReason The stop reason. - * @param stopSequence The stop sequence. - */ + * @param stopReason The stop reason. + * @param stopSequence The stop sequence. + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDelta( - @JsonProperty("stop_reason") String stopReason, - @JsonProperty("stop_sequence") String stopSequence) { + @JsonProperty("stop_reason") String stopReason, + @JsonProperty("stop_sequence") String stopSequence) { } /** * Message delta usage. * @param outputTokens The output tokens. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record MessageDeltaUsage( - @JsonProperty("output_tokens") Integer outputTokens) { + @JsonProperty("output_tokens") Integer outputTokens) { } } // @formatter:on @@ -1339,7 +1356,7 @@ public record MessageDeltaUsage( @JsonIgnoreProperties(ignoreUnknown = true) public record MessageStopEvent( //@formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on @@ -1356,19 +1373,19 @@ public record MessageStopEvent( @JsonIgnoreProperties(ignoreUnknown = true) public record ErrorEvent( // @formatter:off - @JsonProperty("type") EventType type, - @JsonProperty("error") Error error) implements StreamEvent { + @JsonProperty("type") EventType type, + @JsonProperty("error") Error error) implements StreamEvent { /** * Error body. * @param type The error type. * @param message The error message. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record Error( - @JsonProperty("type") String type, - @JsonProperty("message") String message) { + @JsonProperty("type") String type, + @JsonProperty("message") String message) { } } // @formatter:on @@ -1385,7 +1402,7 @@ public record Error( @JsonIgnoreProperties(ignoreUnknown = true) public record PingEvent( // @formatter:off - @JsonProperty("type") EventType type) implements StreamEvent { + @JsonProperty("type") EventType type) implements StreamEvent { } // @formatter:on diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java new file mode 100644 index 00000000000..06a756be42f --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicCacheType.java @@ -0,0 +1,21 @@ +package org.springframework.ai.anthropic.api; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + + EPHEMERAL(() -> new CacheControl("ephemeral")); + + private Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public CacheControl cacheControl() { + return this.value.get(); + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index f636f29a158..8366ca0b712 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -159,7 +159,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) { } else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) { ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null, - null, null, null, null, null, thinkingBlock.thinking(), null); + null, null, null, null, null, null, thinkingBlock.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -176,12 +176,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) { } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) { ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, null, thinking.thinking(), null); + null, null, null, null, null, null, null, thinking.thinking(), null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) { ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(), - null, null, null, null, null, sig.signature(), null, null); + null, null, null, null, null, null, sig.signature(), null, null); contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb)); } else { @@ -204,8 +204,10 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { } if (messageDeltaEvent.usage() != null) { - Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), + messageDeltaEvent.usage().outputTokens(), + contentBlockReference.get().usage.cacheCreationInputTokens(), + contentBlockReference.get().usage.cacheReadInputTokens()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index c78386fb7ce..33301cd7573 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Random; import java.util.stream.Collectors; import org.junit.jupiter.api.Test; @@ -37,6 +38,9 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -70,6 +74,67 @@ public class AnthropicApiIT { } """))); + @Test + void chatWithPromptCacheInAssistantMessage() { + String assistantMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(assistantMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.ASSISTANT); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) + .messages(List.of(chatCompletionMessage)) + .maxTokens(1500) + .temperature(0.8) + .stream(false) + .build(); + + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + + @Test + void chatWithPromptCache() { + String userMessageText = "It could be either a contraction of the full title Quenta Silmarillion (\"Tale of the Silmarils\") or also a plain Genitive which " + + "(as in Ancient Greek) signifies reference. This genitive is translated in English with \"about\" or \"of\" " + + "constructions; the titles of the chapters in The Silmarillion are examples of this genitive in poetic English " + + "(Of the Sindar, Of Men, Of the Darkening of Valinor etc), where \"of\" means \"about\" or \"concerning\". " + + "In the same way, Silmarillion can be taken to mean \"Of/About the Silmarils\""; + + AnthropicMessage chatCompletionMessage = new AnthropicMessage( + List.of(new ContentBlock(userMessageText.repeat(20), AnthropicCacheType.EPHEMERAL.cacheControl())), + Role.USER); + + ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest( + AnthropicApi.ChatModel.CLAUDE_3_HAIKU.getValue(), List.of(chatCompletionMessage), null, 100, 0.8, + false); + AnthropicApi.Usage createdCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest) + .getBody() + .usage(); + + assertThat(createdCacheToken.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createdCacheToken.cacheReadInputTokens()).isEqualTo(0); + + AnthropicApi.Usage readCacheToken = anthropicApi.chatCompletionEntity(chatCompletionRequest).getBody().usage(); + + assertThat(readCacheToken.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheToken.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void chatCompletionEntity() { diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 34beeaba557..d4864e38730 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -683,6 +683,10 @@ xref:ROOT:api/retrieval-augmented-generation.adoc#_questionansweradvisor[Learn a Refer to the xref:ROOT:api/retrieval-augmented-generation.adoc[Retrieval Augmented Generation] guide. +=== Retrieval Augmented Generation + +Refer to the xref:ROOT:api/retrieval-augmented-generation.adoc[Retrieval Augmented Generation] guide. + === Logging The `SimpleLoggerAdvisor` is an advisor that logs the `request` and `response` data of the `ChatClient`. diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java index 6e37fd7548b..72f6dac9748 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AbstractMessage.java @@ -53,19 +53,26 @@ public abstract class AbstractMessage implements Message { @Nullable protected final String textContent; + protected String cache; + /** * Additional options for the message to influence the response, not a generative map. */ protected final Map metadata; - /** - * Create a new AbstractMessage with the given message type, text content, and - * metadata. - * @param messageType the message type - * @param textContent the text content - * @param metadata the metadata - */ - protected AbstractMessage(MessageType messageType, @Nullable String textContent, Map metadata) { + protected AbstractMessage(MessageType messageType, String textContent, Map metadata, String cache) { + Assert.notNull(messageType, "Message type must not be null"); + if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { + Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); + } + this.messageType = messageType; + this.textContent = textContent; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + + protected AbstractMessage(MessageType messageType, String textContent, Map metadata) { Assert.notNull(messageType, "Message type must not be null"); if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) { Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages"); @@ -98,6 +105,20 @@ protected AbstractMessage(MessageType messageType, Resource resource, Map metadata, String cache) { + Assert.notNull(resource, "Resource must not be null"); + try (InputStream inputStream = resource.getInputStream()) { + this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + this.messageType = messageType; + this.metadata = new HashMap<>(metadata); + this.metadata.put(MESSAGE_TYPE, messageType); + this.cache = cache; + } + /** * Get the content of the message. * @return the content of the message @@ -126,6 +147,10 @@ public MessageType getMessageType() { return this.messageType; } + public String getCache() { + return cache; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java index fc005392c34..4c4546925e9 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/messages/UserMessage.java @@ -40,6 +40,17 @@ public class UserMessage extends AbstractMessage implements MediaContent { protected final List media; + public UserMessage(String textContent, String cache) { + this(MessageType.USER, textContent, new ArrayList<>(), Map.of(), cache); + } + + public UserMessage(MessageType messageType, String textContent, Collection media, + Map metadata, String cache) { + super(messageType, textContent, metadata, cache); + Assert.notNull(media, "media data must not be null"); + this.media = new ArrayList<>(media); + } + public UserMessage(String textContent) { this(textContent, new ArrayList<>(), Map.of()); } diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..112be3cc248 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-model-chat-memory-jdbc/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-starter-model-chat-memory-jdbc + jar + Spring AI Starter - JDBC Chat Memory + Spring AI JDBC Chat Memory Starter + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-memory-jdbc + ${project.parent.version} + + + + org.springframework.ai + spring-ai-model-chat-memory-jdbc + ${project.parent.version} + + + +