Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import java.util.Map;
import java.util.stream.Collectors;

import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.util.List;
import java.util.Map;

import org.springframework.util.Assert;
import reactor.core.scheduler.Scheduler;

import org.springframework.ai.chat.client.ChatClientRequest;
Expand All @@ -38,6 +37,7 @@
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.util.Assert;

/**
* Memory is retrieved from a VectorStore added into the prompt's system text.
Expand All @@ -50,7 +50,7 @@
* @author Mark Pollack
* @since 1.0.0
*/
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this by design that we don't want to allow any customization?


public static final String TOP_K = "chat_memory_vector_store_top_k";

Expand Down Expand Up @@ -104,7 +104,7 @@ public static Builder builder(VectorStore chatMemory) {

@Override
public int getOrder() {
return order;
return this.order;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor.vectorstore;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import org.springframework.ai.vectorstore.VectorStore;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,33 +109,6 @@ public enum ClientType {
*/
private Toolcallback toolcallback = new Toolcallback();

/**
* Represents a callback configuration for tools.
* <p>
* This record is used to encapsulate the configuration for enabling or disabling tool
* callbacks in the MCP client.
*
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
* true, the tool callback is active; otherwise, it is disabled.
*/
public static class Toolcallback {

/**
* A boolean flag indicating whether the tool callback is enabled. If true, the
* tool callback is active; otherwise, it is disabled.
*/
private boolean enabled = true;

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public boolean isEnabled() {
return this.enabled;
}

}

public boolean isEnabled() {
return this.enabled;
}
Expand Down Expand Up @@ -193,11 +166,38 @@ public void setRootChangeNotification(boolean rootChangeNotification) {
}

public Toolcallback getToolcallback() {
return toolcallback;
return this.toolcallback;
}

public void setToolcallback(Toolcallback toolcallback) {
this.toolcallback = toolcallback;
}

/**
* Represents a callback configuration for tools.
* <p>
* This record is used to encapsulate the configuration for enabling or disabling tool
* callbacks in the MCP client.
*
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
* true, the tool callback is active; otherwise, it is disabled.
*/
public static class Toolcallback {

/**
* A boolean flag indicating whether the tool callback is enabled. If true, the
* tool callback is active; otherwise, it is disabled.
*/
private boolean enabled = true;

public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public boolean isEnabled() {
return this.enabled;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

package org.springframework.ai.mcp.client.autoconfigure.properties;

import java.time.Duration;

import org.junit.jupiter.api.Test;

import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;

import java.time.Duration;

import static org.assertj.core.api.Assertions.assertThat;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

package org.springframework.ai.mcp.client.autoconfigure.properties;

import java.util.Map;

import org.junit.jupiter.api.Test;

import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,4 @@ public McpAsyncServer mcpAsyncServer(McpServerTransportProvider transportProvide
return serverBuilder.build();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.mcp.server.autoconfigure;

import static org.assertj.core.api.Assertions.assertThat;
package org.springframework.ai.mcp.server.autoconfigure;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import org.junit.jupiter.api.Test;

import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.jackson.JacksonAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.reactive.function.server.RouterFunction;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import static org.assertj.core.api.Assertions.assertThat;

class McpWebFluxServerAutoConfigurationTests {

Expand All @@ -36,7 +37,7 @@ class McpWebFluxServerAutoConfigurationTests {

@Test
void shouldConfigureWebFluxTransportWithCustomObjectMapper() {
this.contextRunner.run((context) -> {
this.contextRunner.run(context -> {
assertThat(context).hasSingleBean(WebFluxSseServerTransportProvider.class);
assertThat(context).hasSingleBean(RouterFunction.class);
assertThat(context).hasSingleBean(McpServerProperties.class);
Expand All @@ -48,13 +49,15 @@ void shouldConfigureWebFluxTransportWithCustomObjectMapper() {
.isEnabled(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse();

// Test with a JSON payload containing unknown fields
// CHECKSTYLE:OFF
String jsonWithUnknownField = """
{
"tools": ["tool1", "tool2"],
"name": "test",
"unknownField": "value"
}
""";
// CHECKSTYLE:ON

// This should not throw an exception
TestMessage message = objectMapper.readValue(jsonWithUnknownField, TestMessage.class);
Expand All @@ -75,7 +78,7 @@ static class TestMessage {
private String name;

public String getName() {
return name;
return this.name;
}

public void setName(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.micrometer.tracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
Expand All @@ -30,7 +31,11 @@
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.micrometer.tracing.Tracer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;

import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationHandler;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import com.datastax.oss.driver.api.core.CqlSession;

import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import javax.sql.DataSource;

import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void setInitializeSchema(DatabaseInitializationMode initializeSchema) {
}

public String getPlatform() {
return platform;
return this.platform;
}

public void setPlatform(String platform) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ public void setUp() {

// Debug: Print current schemas and tables
try {
List<String> schemas = jdbcTemplate.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA",
String.class);
List<String> schemas = this.jdbcTemplate
.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA", String.class);
System.out.println("Available schemas: " + schemas);

List<String> tables = jdbcTemplate
List<String> tables = this.jdbcTemplate
.queryForList("SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES", String.class);
System.out.println("Available tables: " + tables);
}
Expand All @@ -89,22 +89,22 @@ public void setUp() {
// Try a more direct approach with explicit SQL statements
try {
// Drop the table first if it exists to avoid any conflicts
jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
this.jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
System.out.println("Dropped existing table if it existed");

// Create the table with a simplified schema
jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY (" + "conversation_id VARCHAR(36) NOT NULL, "
+ "content LONGVARCHAR NOT NULL, " + "type VARCHAR(10) NOT NULL, "
+ "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
this.jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY ("
+ "conversation_id VARCHAR(36) NOT NULL, " + "content LONGVARCHAR NOT NULL, "
+ "type VARCHAR(10) NOT NULL, " + "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
System.out.println("Created table with simplified schema");

// Create index
jdbcTemplate.execute(
this.jdbcTemplate.execute(
"CREATE INDEX SPRING_AI_CHAT_MEMORY_IDX ON SPRING_AI_CHAT_MEMORY(conversation_id, timestamp DESC)");
System.out.println("Created index");

// Verify table was created
boolean tableExists = jdbcTemplate.queryForObject(
boolean tableExists = this.jdbcTemplate.queryForObject(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
Integer.class) > 0;
System.out.println("Table SPRING_AI_CHAT_MEMORY exists after creation: " + tableExists);
Expand All @@ -125,7 +125,7 @@ public void setUp() {
@Test
public void useAutoConfiguredChatMemoryWithJdbc() {
// Check that the custom schema initializer is present
assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();
assertThat(this.context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();

// Debug: List all schema-hsqldb.sql resources on the classpath
try {
Expand All @@ -144,7 +144,7 @@ public void useAutoConfiguredChatMemoryWithJdbc() {

// Verify the table exists by executing a direct query
try {
boolean tableExists = jdbcTemplate.queryForObject(
boolean tableExists = this.jdbcTemplate.queryForObject(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
Integer.class) > 0;
System.out.println("Table SPRING_AI_CHAT_MEMORY exists: " + tableExists);
Expand All @@ -157,10 +157,10 @@ public void useAutoConfiguredChatMemoryWithJdbc() {
}

// Now test the ChatMemory functionality
assertThat(context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
assertThat(context.getBean(JdbcChatMemoryRepository.class)).isNotNull();
assertThat(this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
assertThat(this.context.getBean(JdbcChatMemoryRepository.class)).isNotNull();

var chatMemory = context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
var chatMemory = this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
var conversationId = java.util.UUID.randomUUID().toString();
var userMessage = new UserMessage("Message from the user");

Expand Down
Loading