-
Notifications
You must be signed in to change notification settings - Fork 76
Expand file tree
/
Copy pathserver.py
More file actions
195 lines (165 loc) · 7.8 KB
/
server.py
File metadata and controls
195 lines (165 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import asyncio
from importlib import import_module
from typing import Any, Awaitable, Optional, Type
from redis import __version__ as redis_py_version
from redisvl.exceptions import RedisSearchError
from redisvl.index import AsyncSearchIndex
from redisvl.mcp.config import MCPConfig, load_mcp_config
from redisvl.mcp.settings import MCPSettings
from redisvl.mcp.tools.search import register_search_tool
from redisvl.mcp.tools.upsert import register_upsert_tool
from redisvl.redis.connection import RedisConnectionFactory, is_version_gte
from redisvl.schema import IndexSchema
try:
    from fastmcp import FastMCP
except ImportError:

    class FastMCP:  # type: ignore[no-redef]
        """Placeholder base class so this module imports without the optional MCP SDK.

        It only records the constructor arguments it was given.
        """

        def __init__(self, *args, **kwargs):
            # Keep positional and keyword arguments around for inspection.
            self.args, self.kwargs = args, kwargs
def resolve_vectorizer_class(class_name: str) -> Type[Any]:
    """Look up ``class_name`` on the ``redisvl.utils.vectorize`` module.

    Args:
        class_name: Attribute name to fetch from the vectorize module.

    Returns:
        The attribute found under ``class_name``.

    Raises:
        ValueError: If the module has no attribute called ``class_name``; the
            original ``AttributeError`` is attached as the cause.
    """
    namespace = import_module("redisvl.utils.vectorize")
    try:
        resolved = getattr(namespace, class_name)
    except AttributeError as missing:
        message = f"Unknown vectorizer class: {class_name}"
        raise ValueError(message) from missing
    return resolved
class RedisVLMCPServer(FastMCP):
    """MCP server exposing RedisVL capabilities for one existing Redis index.

    The constructor only stores settings. Config, the Redis-backed index, the
    vectorizer, and the concurrency semaphore are created by ``startup`` and
    the index and vectorizer are released again by ``shutdown``.
    """

    def __init__(self, settings: MCPSettings):
        """Create a server shell with lazy config, index, and vectorizer state.

        Args:
            settings: MCP settings. ``settings.config`` is passed to the config
                loader during ``startup`` and ``settings.read_only`` decides
                whether the upsert tool is registered.
        """
        super().__init__("redisvl")
        self.mcp_settings = settings
        # Everything below stays unset until ``startup`` has run.
        self.config: Optional[MCPConfig] = None
        self._index: Optional[AsyncSearchIndex] = None
        self._vectorizer: Optional[Any] = None
        self._semaphore: Optional[asyncio.Semaphore] = None
        self._tools_registered = False

    async def startup(self) -> None:
        """Load config, inspect the configured index, and initialize dependencies.

        Each awaited startup step is individually bounded by
        ``config.runtime.startup_timeout_seconds``. If any step fails, or the
        startup task is cancelled, the resources acquired so far are released
        before the error propagates.

        Raises:
            ValueError: If the configured index does not exist, or if the
                vectorizer dimensions disagree with the schema.
            RedisSearchError: For index inspection failures other than a
                missing index.
            asyncio.TimeoutError: If a startup step exceeds the timeout.
        """
        self.config = load_mcp_config(self.mcp_settings.config)
        self._semaphore = asyncio.Semaphore(self.config.runtime.max_concurrency)
        timeout = self.config.runtime.startup_timeout_seconds
        client = None

        try:
            client = await asyncio.wait_for(
                RedisConnectionFactory._get_aredis_connection(
                    redis_url=self.config.server.redis_url
                ),
                timeout=timeout,
            )
            # Issue a real command so connection problems surface here.
            await asyncio.wait_for(client.info("server"), timeout=timeout)

            try:
                index_info = await asyncio.wait_for(
                    AsyncSearchIndex._info(self.config.redis_name, client),
                    timeout=timeout,
                )
            except RedisSearchError as exc:
                # Translate "index is absent" into a configuration error;
                # any other search failure propagates unchanged.
                if self._is_missing_index_error(exc):
                    raise ValueError(
                        f"Configured Redis index '{self.config.redis_name}' does not exist"
                    ) from exc
                raise

            inspected_schema = self.config.inspected_schema_from_index_info(index_info)
            effective_schema = self.config.to_index_schema(inspected_schema)
            self._index = AsyncSearchIndex(schema=effective_schema, redis_client=client)
            # The server acquired this client explicitly during startup, so hand
            # ownership to the index for a single shutdown path.
            self._index._owns_redis_client = True

            # The capability probe talks to Redis, so bound it with the same
            # startup timeout as the other network steps.
            native_hybrid_supported = await asyncio.wait_for(
                self.supports_native_hybrid_search(),
                timeout=timeout,
            )
            self.config.validate_search(
                supports_native_hybrid_search=native_hybrid_supported,
            )
            # Vectorizer construction is synchronous; run it in a worker
            # thread so it does not block the event loop.
            self._vectorizer = await asyncio.wait_for(
                asyncio.to_thread(self._build_vectorizer),
                timeout=timeout,
            )
            self._validate_vectorizer_dims(effective_schema)
            self._register_tools()
        except (Exception, asyncio.CancelledError):
            # CancelledError derives from BaseException, so it is listed
            # explicitly: a cancelled startup must not leak the connection.
            if self._index is not None:
                # The index owns the client by now; shutdown releases both the
                # index and any vectorizer that was built.
                await self.shutdown()
            elif client is not None:
                await client.aclose()
            raise

    async def shutdown(self) -> None:
        """Release owned vectorizer and Redis resources.

        The vectorizer and index references are cleared before they are
        closed, so calling this again afterwards is a no-op. The index is
        disconnected even if closing the vectorizer raises.
        """
        vectorizer = self._vectorizer
        self._vectorizer = None

        try:
            if vectorizer is not None:
                # Prefer an async ``aclose`` and fall back to a sync ``close``;
                # a vectorizer exposing neither needs no cleanup here.
                aclose = getattr(vectorizer, "aclose", None)
                close = getattr(vectorizer, "close", None)
                if callable(aclose):
                    await aclose()
                elif callable(close):
                    close()
        finally:
            if self._index is not None:
                index = self._index
                self._index = None
                await index.disconnect()

    async def get_index(self) -> AsyncSearchIndex:
        """Return the initialized async index or fail if startup has not run.

        Raises:
            RuntimeError: If no index is currently set.
        """
        if self._index is None:
            raise RuntimeError("MCP server has not been started")
        return self._index

    async def get_vectorizer(self) -> Any:
        """Return the initialized vectorizer or fail if startup has not run.

        Raises:
            RuntimeError: If no vectorizer is currently set.
        """
        if self._vectorizer is None:
            raise RuntimeError("MCP server has not been started")
        return self._vectorizer

    async def run_guarded(self, operation_name: str, awaitable: Awaitable[Any]) -> Any:
        """Run an awaitable under the configured concurrency and timeout limits.

        The semaphore is acquired first; the awaitable is then awaited with
        ``config.runtime.request_timeout_seconds`` as its timeout.

        Args:
            operation_name: Label for the operation; currently unused.
            awaitable: The awaitable to run.

        Returns:
            Whatever ``awaitable`` resolves to.

        Raises:
            RuntimeError: If the server has not been started.
            asyncio.TimeoutError: If the awaitable exceeds the request timeout.
        """
        del operation_name  # not used yet; deleting it makes that explicit
        if self.config is None or self._semaphore is None:
            # The awaitable will never be awaited on this path. Close a bare
            # coroutine so it is finalized instead of emitting a
            # "coroutine ... was never awaited" RuntimeWarning.
            if asyncio.iscoroutine(awaitable):
                awaitable.close()
            raise RuntimeError("MCP server has not been started")

        async with self._semaphore:
            return await asyncio.wait_for(
                awaitable,
                timeout=self.config.runtime.request_timeout_seconds,
            )

    def _build_vectorizer(self) -> Any:
        """Instantiate the configured vectorizer class from validated config.

        Returns:
            A new instance of the class named by ``config.vectorizer.class_name``,
            constructed with ``config.vectorizer.to_init_kwargs()``.

        Raises:
            RuntimeError: If the config has not been loaded yet.
        """
        if self.config is None:
            raise RuntimeError("MCP server config not loaded")

        vectorizer_class = resolve_vectorizer_class(self.config.vectorizer.class_name)
        return vectorizer_class(**self.config.vectorizer.to_init_kwargs())

    def _validate_vectorizer_dims(self, schema: IndexSchema) -> None:
        """Fail startup when vectorizer dimensions disagree with schema dimensions.

        The check is skipped when config or vectorizer is unset, or when either
        side does not report a dimension.

        Args:
            schema: Schema whose vector field dimensions are compared.

        Raises:
            ValueError: If both dimensions are known and differ.
        """
        if self.config is None or self._vectorizer is None:
            return

        configured_dims = self.config.get_vector_field_dims(schema)
        # Not every vectorizer necessarily exposes ``dims``; treat a missing
        # attribute as "unknown".
        actual_dims = getattr(self._vectorizer, "dims", None)
        if (
            configured_dims is not None
            and actual_dims is not None
            and configured_dims != actual_dims
        ):
            raise ValueError(
                f"Vectorizer dims {actual_dims} do not match configured vector field dims {configured_dims}"
            )

    async def supports_native_hybrid_search(self) -> bool:
        """Return whether the current runtime supports Redis native hybrid search.

        Three conditions must hold: the installed redis-py version is at least
        7.1.0, the server reports ``redis_version`` of at least 8.4.0, and the
        client's search interface exposes ``hybrid_search``.

        Raises:
            RuntimeError: If the index has not been initialized.
        """
        if self._index is None:
            raise RuntimeError("MCP server has not been started")
        # Check the client library first; this needs no round trip.
        if not is_version_gte(redis_py_version, "7.1.0"):
            return False

        client = await self._index._get_client()
        info = await client.info("server")
        # A missing version field is treated as too old.
        if not is_version_gte(info.get("redis_version", "0.0.0"), "8.4.0"):
            return False

        return hasattr(client.ft(self._index.schema.index.name), "hybrid_search")

    def _register_tools(self) -> None:
        """Register MCP tools once the server is ready.

        Does nothing if tools were already registered or if the base class has
        no ``tool`` attribute (as with the import-safe ``FastMCP`` stand-in).
        The upsert tool is skipped when the settings are read-only.
        """
        if self._tools_registered or not hasattr(self, "tool"):
            return

        register_search_tool(self)
        if not self.mcp_settings.read_only:
            register_upsert_tool(self)
        self._tools_registered = True

    @staticmethod
    def _is_missing_index_error(exc: RedisSearchError) -> bool:
        """Detect the Redis search errors that mean the configured index is absent.

        Args:
            exc: The search error to classify.

        Returns:
            True if the lower-cased message contains "unknown index name" or
            "no such index".
        """
        message = str(exc).lower()
        return "unknown index name" in message or "no such index" in message