Skip to content

Commit 9cf0269

Browse files
authored
Merge branch 'main' into main
Signed-off-by: Senrey_Song <[email protected]>
2 parents bbcb968 + e56eb12 commit 9cf0269

File tree

84 files changed

+8403
-637
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+8403
-637
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,7 @@ qodana.yaml
5151
__pycache__/
5252
*.pyc
5353
tmp
54+
55+
56+
plans
57+

advisors/spring-ai-advisors-vector-store/src/main/java/org/springframework/ai/chat/client/advisor/vectorstore/VectorStoreChatMemoryAdvisor.java

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.HashMap;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.stream.Collectors;
2324

2425
import reactor.core.publisher.Flux;
2526
import reactor.core.publisher.Mono;
@@ -37,9 +38,11 @@
3738
import org.springframework.ai.chat.messages.AssistantMessage;
3839
import org.springframework.ai.chat.messages.Message;
3940
import org.springframework.ai.chat.messages.MessageType;
41+
import org.springframework.ai.chat.messages.SystemMessage;
4042
import org.springframework.ai.chat.messages.UserMessage;
4143
import org.springframework.ai.chat.prompt.PromptTemplate;
4244
import org.springframework.ai.document.Document;
45+
import org.springframework.ai.vectorstore.SearchRequest;
4346
import org.springframework.ai.vectorstore.VectorStore;
4447
import org.springframework.util.Assert;
4548

@@ -122,31 +125,23 @@ public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorC
122125
String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
123126
int topK = getChatMemoryTopK(request.context());
124127
String filter = DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'";
125-
var searchRequest = org.springframework.ai.vectorstore.SearchRequest.builder()
126-
.query(query)
127-
.topK(topK)
128-
.filterExpression(filter)
129-
.build();
130-
java.util.List<org.springframework.ai.document.Document> documents = this.vectorStore
131-
.similaritySearch(searchRequest);
128+
SearchRequest searchRequest = SearchRequest.builder().query(query).topK(topK).filterExpression(filter).build();
129+
List<Document> documents = this.vectorStore.similaritySearch(searchRequest);
132130

133131
String longTermMemory = documents == null ? ""
134-
: documents.stream()
135-
.map(org.springframework.ai.document.Document::getText)
136-
.collect(java.util.stream.Collectors.joining(System.lineSeparator()));
132+
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
137133

138-
org.springframework.ai.chat.messages.SystemMessage systemMessage = request.prompt().getSystemMessage();
134+
SystemMessage systemMessage = request.prompt().getSystemMessage();
139135
String augmentedSystemText = this.systemPromptTemplate
140-
.render(java.util.Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
136+
.render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
141137

142138
ChatClientRequest processedChatClientRequest = request.mutate()
143139
.prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
144140
.build();
145141

146-
org.springframework.ai.chat.messages.UserMessage userMessage = processedChatClientRequest.prompt()
147-
.getUserMessage();
142+
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
148143
if (userMessage != null) {
149-
this.vectorStore.write(toDocuments(java.util.List.of(userMessage), conversationId));
144+
this.vectorStore.write(toDocuments(List.of(userMessage), conversationId));
150145
}
151146

152147
return processedChatClientRequest;
@@ -186,10 +181,11 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
186181
}
187182

188183
private List<Document> toDocuments(List<Message> messages, String conversationId) {
189-
List<Document> docs = messages.stream()
184+
return messages.stream()
190185
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
191186
.map(message -> {
192-
var metadata = new HashMap<>(message.getMetadata() != null ? message.getMetadata() : new HashMap<>());
187+
Map<String, Object> metadata = new HashMap<>(
188+
message.getMetadata() != null ? message.getMetadata() : new HashMap<>());
193189
metadata.put(DOCUMENT_METADATA_CONVERSATION_ID, conversationId);
194190
metadata.put(DOCUMENT_METADATA_MESSAGE_TYPE, message.getMessageType().name());
195191
if (message instanceof UserMessage userMessage) {
@@ -208,8 +204,6 @@ else if (message instanceof AssistantMessage assistantMessage) {
208204
throw new RuntimeException("Unknown message type: " + message.getMessageType());
209205
})
210206
.toList();
211-
212-
return docs;
213207
}
214208

215209
/**

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/main/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfiguration.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
import io.modelcontextprotocol.client.McpSyncClient;
2323

2424
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
25+
import org.springframework.ai.mcp.McpToolFilter;
26+
import org.springframework.ai.mcp.McpToolNamePrefixGenerator;
2527
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
28+
import org.springframework.ai.mcp.ToolContextToMcpMetaConverter;
2629
import org.springframework.ai.mcp.client.common.autoconfigure.properties.McpClientCommonProperties;
2730
import org.springframework.beans.factory.ObjectProvider;
2831
import org.springframework.boot.autoconfigure.AutoConfiguration;
@@ -45,22 +48,45 @@ public class McpToolCallbackAutoConfiguration {
4548
* <p>
4649
* These callbacks enable integration with Spring AI's tool execution framework,
4750
* allowing MCP tools to be used as part of AI interactions.
51+
* @param syncClientsToolFilter list of {@link McpToolFilter}s for the sync client to
52+
* filter the discovered tools
4853
* @param syncMcpClients provider of MCP sync clients
54+
* @param mcpToolNamePrefixGenerator the tool name prefix generator
4955
* @return list of tool callbacks for MCP integration
5056
*/
5157
@Bean
5258
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC",
5359
matchIfMissing = true)
54-
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<List<McpSyncClient>> syncMcpClients) {
60+
public SyncMcpToolCallbackProvider mcpToolCallbacks(ObjectProvider<McpToolFilter> syncClientsToolFilter,
61+
ObjectProvider<List<McpSyncClient>> syncMcpClients,
62+
ObjectProvider<McpToolNamePrefixGenerator> mcpToolNamePrefixGenerator,
63+
ObjectProvider<ToolContextToMcpMetaConverter> toolContextToMcpMetaConverter) {
5564
List<McpSyncClient> mcpClients = syncMcpClients.stream().flatMap(List::stream).toList();
56-
return new SyncMcpToolCallbackProvider(mcpClients);
65+
return SyncMcpToolCallbackProvider.builder()
66+
.mcpClients(mcpClients)
67+
.toolFilter(syncClientsToolFilter.getIfUnique((() -> (McpSyncClient, tool) -> true)))
68+
.toolNamePrefixGenerator(
69+
mcpToolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.defaultGenerator()))
70+
.toolContextToMcpMetaConverter(
71+
toolContextToMcpMetaConverter.getIfUnique(() -> ToolContextToMcpMetaConverter.defaultConverter()))
72+
.build();
5773
}
5874

5975
@Bean
6076
@ConditionalOnProperty(prefix = McpClientCommonProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC")
61-
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<List<McpAsyncClient>> mcpClientsProvider) {
77+
public AsyncMcpToolCallbackProvider mcpAsyncToolCallbacks(ObjectProvider<McpToolFilter> asyncClientsToolFilter,
78+
ObjectProvider<List<McpAsyncClient>> mcpClientsProvider,
79+
ObjectProvider<McpToolNamePrefixGenerator> toolNamePrefixGenerator,
80+
ObjectProvider<ToolContextToMcpMetaConverter> toolContextToMcpMetaConverter) { // TODO
6281
List<McpAsyncClient> mcpClients = mcpClientsProvider.stream().flatMap(List::stream).toList();
63-
return new AsyncMcpToolCallbackProvider(mcpClients);
82+
return AsyncMcpToolCallbackProvider.builder()
83+
.toolFilter(asyncClientsToolFilter.getIfUnique(() -> (McpAsyncClient, tool) -> true))
84+
.toolNamePrefixGenerator(
85+
toolNamePrefixGenerator.getIfUnique(() -> McpToolNamePrefixGenerator.defaultGenerator()))
86+
.toolContextToMcpMetaConverter(
87+
toolContextToMcpMetaConverter.getIfUnique(() -> ToolContextToMcpMetaConverter.defaultConverter()))
88+
.mcpClients(mcpClients)
89+
.build();
6490
}
6591

6692
public static class McpToolCallbackAutoConfigurationCondition extends AllNestedConditions {

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common/src/test/java/org/springframework/ai/mcp/client/common/autoconfigure/McpToolCallbackAutoConfigurationConditionTests.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,28 @@
1616

1717
package org.springframework.ai.mcp.client.common.autoconfigure;
1818

19+
import java.lang.reflect.Field;
20+
import java.util.List;
21+
22+
import io.modelcontextprotocol.client.McpAsyncClient;
23+
import io.modelcontextprotocol.client.McpSyncClient;
24+
import io.modelcontextprotocol.spec.McpSchema;
1925
import org.junit.jupiter.api.Test;
26+
import reactor.core.publisher.Mono;
2027

28+
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
29+
import org.springframework.ai.mcp.McpConnectionInfo;
30+
import org.springframework.ai.mcp.McpToolFilter;
31+
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
2132
import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration.McpToolCallbackAutoConfigurationCondition;
2233
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
2334
import org.springframework.context.annotation.Bean;
2435
import org.springframework.context.annotation.Conditional;
2536
import org.springframework.context.annotation.Configuration;
2637

2738
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.when;
2841

2942
/**
3043
* Tests for {@link McpToolCallbackAutoConfigurationCondition}.
@@ -73,6 +86,62 @@ void doesMatchWhenBothPropertiesAreMissing() {
7386
this.contextRunner.run(context -> assertThat(context).hasBean("testBean"));
7487
}
7588

89+
@Test
90+
void verifySyncToolCallbackFilterConfiguration() {
91+
this.contextRunner
92+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
93+
.withPropertyValues("spring.ai.mcp.client.type=SYNC")
94+
.run(context -> {
95+
assertThat(context).hasBean("mcpClientFilter");
96+
SyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(SyncMcpToolCallbackProvider.class);
97+
Field field = SyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
98+
field.setAccessible(true);
99+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
100+
McpSyncClient syncClient1 = mock(McpSyncClient.class);
101+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
102+
when(syncClient1.getClientInfo()).thenReturn(clientInfo1);
103+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
104+
when(tool1.name()).thenReturn("tool1");
105+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
106+
when(tool2.name()).thenReturn("tool2");
107+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
108+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
109+
when(syncClient1.listTools()).thenReturn(listToolsResult1);
110+
assertThat(toolFilter.test(new McpConnectionInfo(null, syncClient1.getClientInfo(), null), tool1))
111+
.isFalse();
112+
assertThat(toolFilter.test(new McpConnectionInfo(null, syncClient1.getClientInfo(), null), tool2))
113+
.isTrue();
114+
});
115+
}
116+
117+
@Test
118+
void verifyAsyncToolCallbackFilterConfiguration() {
119+
this.contextRunner
120+
.withUserConfiguration(McpToolCallbackAutoConfiguration.class, McpClientFilterConfiguration.class)
121+
.withPropertyValues("spring.ai.mcp.client.type=ASYNC")
122+
.run(context -> {
123+
assertThat(context).hasBean("mcpClientFilter");
124+
AsyncMcpToolCallbackProvider toolCallbackProvider = context.getBean(AsyncMcpToolCallbackProvider.class);
125+
Field field = AsyncMcpToolCallbackProvider.class.getDeclaredField("toolFilter");
126+
field.setAccessible(true);
127+
McpToolFilter toolFilter = (McpToolFilter) field.get(toolCallbackProvider);
128+
McpAsyncClient asyncClient1 = mock(McpAsyncClient.class);
129+
var clientInfo1 = new McpSchema.Implementation("client1", "1.0.0");
130+
when(asyncClient1.getClientInfo()).thenReturn(clientInfo1);
131+
McpSchema.Tool tool1 = mock(McpSchema.Tool.class);
132+
when(tool1.name()).thenReturn("tool1");
133+
McpSchema.Tool tool2 = mock(McpSchema.Tool.class);
134+
when(tool2.name()).thenReturn("tool2");
135+
McpSchema.ListToolsResult listToolsResult1 = mock(McpSchema.ListToolsResult.class);
136+
when(listToolsResult1.tools()).thenReturn(List.of(tool1, tool2));
137+
when(asyncClient1.listTools()).thenReturn(Mono.just(listToolsResult1));
138+
assertThat(toolFilter.test(new McpConnectionInfo(null, asyncClient1.getClientInfo(), null), tool1))
139+
.isFalse();
140+
assertThat(toolFilter.test(new McpConnectionInfo(null, asyncClient1.getClientInfo(), null), tool2))
141+
.isTrue();
142+
});
143+
}
144+
76145
@Configuration
77146
@Conditional(McpToolCallbackAutoConfigurationCondition.class)
78147
static class TestConfiguration {
@@ -84,4 +153,22 @@ String testBean() {
84153

85154
}
86155

156+
@Configuration
157+
static class McpClientFilterConfiguration {
158+
159+
@Bean
160+
McpToolFilter mcpClientFilter() {
161+
return new McpToolFilter() {
162+
@Override
163+
public boolean test(McpConnectionInfo metadata, McpSchema.Tool tool) {
164+
if (metadata.clientInfo().name().equals("client1") && tool.name().contains("tool1")) {
165+
return false;
166+
}
167+
return true;
168+
}
169+
};
170+
}
171+
172+
}
173+
87174
}

0 commit comments

Comments
 (0)