diff --git a/.gitignore b/.gitignore index cd800581..16c99a57 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,9 @@ dmypy.json # Cython debug symbols cython_debug/ +# Codex +.codex/ + # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore diff --git a/redisvl/mcp/config.py b/redisvl/mcp/config.py index 939c7c6d..e226986b 100644 --- a/redisvl/mcp/config.py +++ b/redisvl/mcp/config.py @@ -2,7 +2,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -71,6 +71,101 @@ class MCPServerConfig(BaseModel): redis_url: str = Field(..., min_length=1) +class MCPIndexSearchConfig(BaseModel): + """Configured search mode and query tuning for the bound index. + + The MCP request contract only exposes query text, filtering, pagination, and + field projection. Search mode and query-tuning behavior are owned entirely by + YAML config and validated here. + """ + + type: Literal["vector", "fulltext", "hybrid"] + params: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def _validate_params(self) -> "MCPIndexSearchConfig": + """Reject params that do not belong to the configured search mode.""" + allowed_params = { + "vector": { + "hybrid_policy", + "batch_size", + "ef_runtime", + "epsilon", + "search_window_size", + "use_search_history", + "search_buffer_capacity", + "normalize_vector_distance", + }, + "fulltext": { + "text_scorer", + "stopwords", + "text_weights", + }, + "hybrid": { + "text_scorer", + "stopwords", + "text_weights", + "vector_search_method", + "knn_ef_runtime", + "range_radius", + "range_epsilon", + "combination_method", + "rrf_window", + "rrf_constant", + "linear_text_weight", + }, + } + invalid_keys = sorted(set(self.params) - allowed_params[self.type]) + if invalid_keys: + raise ValueError( + "search.params contains keys incompatible with " + f"search.type '{self.type}': {', '.join(invalid_keys)}" + ) + + if ( + "linear_text_weight" in self.params + and self.params.get("combination_method") != "LINEAR" + ): + raise ValueError( + "search.params.linear_text_weight requires combination_method to be LINEAR" + ) + return self + + def to_query_params(self) -> Dict[str, Any]: + """Return normalized query kwargs exactly as configured.""" + return dict(self.params) + + def validate_runtime_capabilities( + self, *, supports_native_hybrid_search: bool + ) -> None: + """Fail startup when hybrid config depends on native-only FT.SEARCH params.""" + if self.type != "hybrid" or supports_native_hybrid_search: + return + + unsupported_params = set() + if self.params.get("combination_method") not in (None, "LINEAR"): + unsupported_params.add("combination_method") + unsupported_params.update( + key + for key in ( + "vector_search_method", + "knn_ef_runtime", + "range_radius", + "range_epsilon", + "rrf_window", + "rrf_constant", + ) + if key in self.params + ) + + if unsupported_params: + unsupported_list = ", ".join(sorted(unsupported_params)) + raise ValueError( + "search.params requires native hybrid search support for: " + f"{unsupported_list}" + ) + + class MCPSchemaOverrideField(BaseModel): """Allowed schema override fragment for one already-discovered field.""" @@ -91,6 +186,7 @@ class MCPIndexBindingConfig(BaseModel): redis_name: str = Field(..., min_length=1) vectorizer: MCPVectorizerConfig + search: MCPIndexSearchConfig runtime: MCPRuntimeConfig schema_overrides: MCPSchemaOverrides = Field(default_factory=MCPSchemaOverrides) @@ -134,6 +230,11 @@ def vectorizer(self) -> MCPVectorizerConfig: """Expose the sole binding's vectorizer config for phase 1.""" return self.binding.vectorizer + @property + def search(self) -> MCPIndexSearchConfig: + """Expose the sole binding's configured search behavior.""" + return self.binding.search + @property def redis_name(self) -> str: """Return the existing Redis index name that must be inspected at startup.""" @@ -255,6 +356,16 @@ def get_vector_field_dims(self, schema: IndexSchema) -> Optional[int]: attrs = self.get_vector_field(schema).attrs return getattr(attrs, "dims", None) + def validate_search( + self, + *, + supports_native_hybrid_search: bool, + ) -> None: + """Validate configured search behavior against current runtime support.""" + self.search.validate_runtime_capabilities( + supports_native_hybrid_search=supports_native_hybrid_search + ) + def _substitute_env(value: Any) -> Any: """Recursively resolve `${VAR}` and `${VAR:-default}` placeholders.""" diff --git a/redisvl/mcp/errors.py b/redisvl/mcp/errors.py index 54fb59bc..6befad3b 100644 --- a/redisvl/mcp/errors.py +++ b/redisvl/mcp/errors.py @@ -12,6 +12,7 @@ class MCPErrorCode(str, Enum): """Stable internal error codes exposed by the MCP framework.""" INVALID_REQUEST = "invalid_request" + INVALID_FILTER = "invalid_filter" DEPENDENCY_MISSING = "dependency_missing" BACKEND_UNAVAILABLE = "backend_unavailable" INTERNAL_ERROR = "internal_error" diff --git a/redisvl/mcp/filters.py b/redisvl/mcp/filters.py new file mode 100644 index 00000000..cc870439 --- /dev/null +++ b/redisvl/mcp/filters.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from typing import Any, Iterable, Optional + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.query.filter import FilterExpression, Num, Tag, Text +from redisvl.schema import IndexSchema + + +def parse_filter( + value: Optional[str | dict[str, Any]], schema: IndexSchema +) -> Optional[str | FilterExpression]: + """Parse an MCP filter value into a RedisVL filter representation.""" + if value is None: + return None + if isinstance(value, str): + return value + if not isinstance(value, dict): + raise RedisVLMCPError( + "filter must be a string or object", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return _parse_expression(value, schema) + + +def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpression: + logical_keys = [key for key in ("and", "or", "not") if key in value] + if logical_keys: + if len(logical_keys) != 1 or len(value) != 1: + raise RedisVLMCPError( + "logical filter objects must contain exactly one operator", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + logical_key = logical_keys[0] + if logical_key == "not": + child = value["not"] + if not isinstance(child, dict): + raise RedisVLMCPError( + "not filter must wrap a single object expression", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return FilterExpression(f"(-({str(_parse_expression(child, schema))}))") + + children = value[logical_key] + if not isinstance(children, list) or not children: + raise RedisVLMCPError( + f"{logical_key} filter must contain a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + expressions: list[FilterExpression] = [] + for child in children: + if not isinstance(child, dict): + raise RedisVLMCPError( + "logical filter children must be objects", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + expressions.append(_parse_expression(child, schema)) + + combined = expressions[0] + for child in expressions[1:]: + combined = combined & child if logical_key == "and" else combined | child + return combined + + field_name = value.get("field") + op = value.get("op") + if not isinstance(field_name, str) or not field_name.strip(): + raise RedisVLMCPError( + "filter.field must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + if not isinstance(op, str) or not op.strip(): + raise RedisVLMCPError( + "filter.op must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + field = schema.fields.get(field_name) + if field is None: + raise RedisVLMCPError( + f"Unknown filter field: {field_name}", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + normalized_op = op.lower() + if normalized_op == "exists": + return FilterExpression(f"(-ismissing(@{field_name}))") + + if "value" not in value: + raise RedisVLMCPError( + "filter.value is required for this operator", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + operand = value["value"] + if field.type == "tag": + return _parse_tag_expression(field_name, normalized_op, operand) + if field.type == "text": + return _parse_text_expression(field_name, normalized_op, operand) + if field.type == "numeric": + return _parse_numeric_expression(field_name, normalized_op, operand) + + raise RedisVLMCPError( + f"Unsupported filter field type for {field_name}: {field.type}", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_tag_expression(field_name: str, op: str, operand: Any) -> FilterExpression: + field = Tag(field_name) + if op == "eq": + return field == _require_string(operand, field_name, op) + if op == "ne": + return field != _require_string(operand, field_name, op) + if op == "in": + return field == _require_string_list(operand, field_name, op) + if op == "like": + return field % _require_string(operand, field_name, op) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for tag field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_text_expression(field_name: str, op: str, operand: Any) -> FilterExpression: + field = Text(field_name) + if op == "eq": + return field == _require_string(operand, field_name, op) + if op == "ne": + return field != _require_string(operand, field_name, op) + if op == "like": + return field % _require_string(operand, field_name, op) + if op == "in": + return _combine_or( + [field == item for item in _require_string_list(operand, field_name, op)] + ) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for text field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_numeric_expression( + field_name: str, op: str, operand: Any +) -> FilterExpression: + field = Num(field_name) + if op == "eq": + return field == _require_number(operand, field_name, op) + if op == "ne": + return field != _require_number(operand, field_name, op) + if op == "gt": + return field > _require_number(operand, field_name, op) + if op == "gte": + return field >= _require_number(operand, field_name, op) + if op == "lt": + return field < _require_number(operand, field_name, op) + if op == "lte": + return field <= _require_number(operand, field_name, op) + if op == "in": + return _combine_or( + [field == item for item in _require_number_list(operand, field_name, op)] + ) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for numeric field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _combine_or(expressions: Iterable[FilterExpression]) -> FilterExpression: + expression_list = list(expressions) + if not expression_list: + raise RedisVLMCPError( + "in operator requires a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + combined = expression_list[0] + for expression in expression_list[1:]: + combined = combined | expression + return combined + + +def _require_string(value: Any, field_name: str, op: str) -> str: + if not isinstance(value, str) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return value + + +def _require_string_list(value: Any, field_name: str, op: str) -> list[str]: + if not isinstance(value, list) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + strings = [_require_string(item, field_name, op) for item in value] + return strings + + +def _require_number(value: Any, field_name: str, op: str) -> int | float: + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be numeric", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return value + + +def _require_number_list(value: Any, field_name: str, op: str) -> list[int | float]: + if not isinstance(value, list) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return [_require_number(item, field_name, op) for item in value] diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index 12e1d6db..fb0df041 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -2,11 +2,13 @@ from importlib import import_module from typing import Any, Awaitable, Optional, Type +from redis import __version__ as redis_py_version + from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.mcp.config import MCPConfig, load_mcp_config from redisvl.mcp.settings import MCPSettings -from redisvl.redis.connection import RedisConnectionFactory +from redisvl.redis.connection import RedisConnectionFactory, is_version_gte from redisvl.schema import IndexSchema try: @@ -40,12 +42,15 @@ def __init__(self, settings: MCPSettings): self.config: Optional[MCPConfig] = None self._index: Optional[AsyncSearchIndex] = None self._vectorizer: Optional[Any] = None + self._supports_native_hybrid_search: Optional[bool] = None self._semaphore: Optional[asyncio.Semaphore] = None + self._tools_registered = False async def startup(self) -> None: """Load config, inspect the configured index, and initialize dependencies.""" self.config = load_mcp_config(self.mcp_settings.config) self._semaphore = asyncio.Semaphore(self.config.runtime.max_concurrency) + self._supports_native_hybrid_search = None timeout = self.config.runtime.startup_timeout_seconds client = None @@ -76,12 +81,16 @@ async def startup(self) -> None: # The server acquired this client explicitly during startup, so hand # ownership to the index for a single shutdown path. self._index._owns_redis_client = True + self.config.validate_search( + supports_native_hybrid_search=await self.supports_native_hybrid_search(), + ) self._vectorizer = await asyncio.wait_for( asyncio.to_thread(self._build_vectorizer), timeout=timeout, ) self._validate_vectorizer_dims(effective_schema) + self._register_tools() except Exception: if self._index is not None: await self.shutdown() @@ -102,6 +111,7 @@ async def shutdown(self) -> None: elif callable(close): close() finally: + self._supports_native_hybrid_search = None if self._index is not None: index = self._index self._index = None @@ -155,6 +165,37 @@ def _validate_vectorizer_dims(self, schema: IndexSchema) -> None: f"Vectorizer dims {actual_dims} do not match configured vector field dims {configured_dims}" ) + async def supports_native_hybrid_search(self) -> bool: + """Return whether the current runtime supports Redis native hybrid search.""" + if self._supports_native_hybrid_search is not None: + return self._supports_native_hybrid_search + if self._index is None: + raise RuntimeError("MCP server has not been started") + if not is_version_gte(redis_py_version, "7.1.0"): + self._supports_native_hybrid_search = False + return False + + client = await self._index._get_client() + info = await client.info("server") + if not is_version_gte(info.get("redis_version", "0.0.0"), "8.4.0"): + self._supports_native_hybrid_search = False + return False + + self._supports_native_hybrid_search = hasattr( + client.ft(self._index.schema.index.name), "hybrid_search" + ) + return self._supports_native_hybrid_search + + def _register_tools(self) -> None: + """Register MCP tools once the server is ready.""" + if self._tools_registered or not hasattr(self, "tool"): + return + + from redisvl.mcp.tools.search import register_search_tool + + register_search_tool(self) + self._tools_registered = True + @staticmethod def _is_missing_index_error(exc: RedisSearchError) -> bool: """Detect the Redis search errors that mean the configured index is absent.""" diff --git a/redisvl/mcp/tools/__init__.py b/redisvl/mcp/tools/__init__.py new file mode 100644 index 00000000..e47aef7c --- /dev/null +++ b/redisvl/mcp/tools/__init__.py @@ -0,0 +1,3 @@ +from redisvl.mcp.tools.search import search_records + +__all__ = ["search_records"] diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py new file mode 100644 index 00000000..29da0496 --- /dev/null +++ b/redisvl/mcp/tools/search.py @@ -0,0 +1,402 @@ +import asyncio +import inspect +from typing import Any, Optional, Union + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError, map_exception +from redisvl.mcp.filters import parse_filter +from redisvl.query import AggregateHybridQuery, HybridQuery, TextQuery, VectorQuery + +DEFAULT_SEARCH_DESCRIPTION = "Search records in the configured Redis index." + +_NATIVE_HYBRID_DEFAULTS = { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, +} + + +def _validate_request( + *, + query: str, + limit: Optional[int], + offset: int, + return_fields: Optional[list[str]], + server: Any, + index: Any, +) -> tuple[int, list[str]]: + """Validate a `search-records` request and resolve default projection. + + The MCP caller can only supply query text, pagination, filters, and return + fields. Search mode and tuning are sourced from config, so this validation + step focuses only on the public request contract. + """ + + runtime = server.config.runtime + + if not isinstance(query, str) or not query.strip(): + raise RedisVLMCPError( + "query must be a non-empty string", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + effective_limit = runtime.default_limit if limit is None else limit + if not isinstance(effective_limit, int) or effective_limit <= 0: + raise RedisVLMCPError( + "limit must be greater than 0", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if effective_limit > runtime.max_limit: + raise RedisVLMCPError( + f"limit must be less than or equal to {runtime.max_limit}", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if not isinstance(offset, int) or offset < 0: + raise RedisVLMCPError( + "offset must be greater than or equal to 0", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + schema_fields = set(index.schema.field_names) + vector_field_name = runtime.vector_field_name + + if return_fields is None: + fields = [ + field_name + for field_name in index.schema.field_names + if field_name != vector_field_name + ] + else: + if not isinstance(return_fields, list): + raise RedisVLMCPError( + "return_fields must be a list of field names", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + fields = [] + for field_name in return_fields: + if not isinstance(field_name, str) or not field_name: + raise RedisVLMCPError( + "return_fields must contain non-empty strings", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if field_name not in schema_fields: + raise RedisVLMCPError( + f"Unknown return field '{field_name}'", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if field_name == vector_field_name: + raise RedisVLMCPError( + f"Vector field '{vector_field_name}' cannot be returned", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + fields.append(field_name) + + return effective_limit, fields + + +def _normalize_record( + result: dict[str, Any], score_field: str, score_type: str +) -> dict[str, Any]: + """Convert one RedisVL result into the stable MCP result shape.""" + score = result.get(score_field) + if score is None and score_field == "score": + score = result.get("__score") + if score is None: + raise RedisVLMCPError( + f"Search result missing expected score field '{score_field}'", + code=MCPErrorCode.INTERNAL_ERROR, + retryable=False, + ) + + record = dict(result) + doc_id = record.pop("id", None) + if doc_id is None: + doc_id = record.pop("__key", None) + if doc_id is None: + doc_id = record.pop("key", None) + if doc_id is None: + raise RedisVLMCPError( + "Search result missing id", + code=MCPErrorCode.INTERNAL_ERROR, + retryable=False, + ) + + for field_name in ( + "vector_distance", + "score", + "__score", + "text_score", + "vector_similarity", + "hybrid_score", + ): + record.pop(field_name, None) + + return { + "id": doc_id, + "score": float(score), + "score_type": score_type, + "record": record, + } + + +async def _embed_query(vectorizer: Any, query: str) -> Any: + """Embed the query text, tolerating vectorizers without real async support.""" + aembed = getattr(vectorizer, "aembed", None) + if callable(aembed): + try: + return await aembed(query) + except NotImplementedError: + pass + embed = getattr(vectorizer, "embed") + if inspect.iscoroutinefunction(embed): + return await embed(query) + return await asyncio.to_thread(embed, query) + + +def _get_configured_search(server: Any) -> tuple[str, dict[str, Any]]: + """Return the configured search mode and normalized query params.""" + search_config = server.config.search + return search_config.type, search_config.to_query_params() + + +def _build_native_hybrid_kwargs( + *, + query: str, + embedding: Any, + runtime: Any, + filter_expression: Any, + return_fields: list[str], + num_results: int, + search_params: dict[str, Any], +) -> dict[str, Any]: + """Build native `HybridQuery` kwargs from MCP config-owned hybrid params.""" + params = dict(search_params) + combination_method = params.setdefault( + "combination_method", + _NATIVE_HYBRID_DEFAULTS["combination_method"], + ) + if combination_method == "LINEAR": + linear_text_weight = params.pop( + "linear_text_weight", + _NATIVE_HYBRID_DEFAULTS["linear_text_weight"], + ) + params["linear_alpha"] = linear_text_weight + else: + params.pop("linear_text_weight", None) + + return { + "text": query, + "text_field_name": runtime.text_field_name, + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": ["__key", *return_fields], + "num_results": num_results, + "yield_text_score_as": "text_score", + "yield_vsim_score_as": "vector_similarity", + "yield_combined_score_as": "hybrid_score", + **params, + } + + +def _build_fallback_hybrid_kwargs( + *, + query: str, + embedding: Any, + runtime: Any, + filter_expression: Any, + return_fields: list[str], + num_results: int, + search_params: dict[str, Any], +) -> dict[str, Any]: + """Build aggregate fallback kwargs while preserving MCP fusion semantics.""" + params = { + key: value + for key, value in search_params.items() + if key in {"text_scorer", "stopwords", "text_weights"} + } + linear_text_weight = search_params.get("linear_text_weight", 0.3) + params["alpha"] = 1 - linear_text_weight + + return { + "text": query, + "text_field_name": runtime.text_field_name, + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": ["__key", *return_fields], + "num_results": num_results, + **params, + } + + +async def _build_query( + *, + server: Any, + index: Any, + query: str, + limit: int, + offset: int, + filter_value: Optional[Union[str, dict[str, Any]]], + return_fields: list[str], +) -> tuple[Any, str, str, str]: + """Build the RedisVL query object from configured search mode and params. + + Returns the query instance, the raw score field to read from RedisVL + results, the public MCP `score_type`, and the configured `search_type`. + """ + runtime = server.config.runtime + search_type, search_params = _get_configured_search(server) + num_results = limit + offset + filter_expression = parse_filter(filter_value, index.schema) + + if search_type == "vector": + vectorizer = await server.get_vectorizer() + embedding = await _embed_query(vectorizer, query) + vector_kwargs = { + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": return_fields, + "num_results": num_results, + **search_params, + } + if "normalize_vector_distance" not in vector_kwargs: + vector_kwargs["normalize_vector_distance"] = True + normalize_vector_distance = vector_kwargs["normalize_vector_distance"] + return ( + VectorQuery(**vector_kwargs), + "vector_distance", + ( + "vector_distance_normalized" + if normalize_vector_distance + else "vector_distance" + ), + search_type, + ) + + if search_type == "fulltext": + return ( + TextQuery( + text=query, + text_field_name=runtime.text_field_name, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + **search_params, + ), + "score", + "text_score", + search_type, + ) + + vectorizer = await server.get_vectorizer() + embedding = await _embed_query(vectorizer, query) + if await server.supports_native_hybrid_search(): + native_query = HybridQuery( + **_build_native_hybrid_kwargs( + query=query, + embedding=embedding, + runtime=runtime, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + search_params=search_params, + ) + ) + native_query.postprocessing_config.apply(__key="@__key") + return native_query, "hybrid_score", "hybrid_score", search_type + + fallback_query = AggregateHybridQuery( + **_build_fallback_hybrid_kwargs( + query=query, + embedding=embedding, + runtime=runtime, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + search_params=search_params, + ) + ) + return fallback_query, "hybrid_score", "hybrid_score", search_type + + +async def search_records( + server: Any, + *, + query: str, + limit: Optional[int] = None, + offset: int = 0, + filter: Optional[Union[str, dict[str, Any]]] = None, + return_fields: Optional[list[str]] = None, +) -> dict[str, Any]: + """Execute `search-records` against the configured Redis index binding.""" + try: + index = await server.get_index() + effective_limit, effective_return_fields = _validate_request( + query=query, + limit=limit, + offset=offset, + return_fields=return_fields, + server=server, + index=index, + ) + built_query, score_field, score_type, search_type = await _build_query( + server=server, + index=index, + query=query.strip(), + limit=effective_limit, + offset=offset, + filter_value=filter, + return_fields=effective_return_fields, + ) + raw_results = await server.run_guarded( + "search-records", + index.query(built_query), + ) + sliced_results = raw_results[offset : offset + effective_limit] + return { + "search_type": search_type, + "offset": offset, + "limit": effective_limit, + "results": [ + _normalize_record(result, score_field, score_type) + for result in sliced_results + ], + } + except RedisVLMCPError: + raise + except Exception as exc: + raise map_exception(exc) from exc + + +def register_search_tool(server: Any) -> None: + """Register the MCP `search-records` tool with its config-owned contract.""" + description = ( + server.mcp_settings.tool_search_description or DEFAULT_SEARCH_DESCRIPTION + ) + + async def search_records_tool( + query: str, + limit: Optional[int] = None, + offset: int = 0, + filter: Optional[Union[str, dict[str, Any]]] = None, + return_fields: Optional[list[str]] = None, + ): + """FastMCP wrapper for the `search-records` tool.""" + return await search_records( + server, + query=query, + limit=limit, + offset=offset, + filter=filter, + return_fields=return_fields, + ) + + server.tool(name="search-records", description=description)(search_records_tool) diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 0295568f..f30870f2 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -164,7 +164,7 @@ def __eq__(self, other: Union[List[str], str]) -> "FilterExpression": return FilterExpression(str(self)) @check_operator_misuse - def __ne__(self, other) -> "FilterExpression": + def __ne__(self, other: Union[List[str], str]) -> "FilterExpression": """Create a Tag inequality filter expression. Args: @@ -298,7 +298,7 @@ def __eq__(self, other) -> "FilterExpression": return FilterExpression(str(self)) @check_operator_misuse - def __ne__(self, other) -> "FilterExpression": + def __ne__(self, other: GeoRadius) -> "FilterExpression": """Create a geographic filter outside of a specified GeoRadius. Args: @@ -349,11 +349,11 @@ class Num(FilterField): SUPPORTED_VAL_TYPES = (int, float, tuple, type(None)) - def __eq__(self, other: int) -> "FilterExpression": + def __eq__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric equality filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -364,11 +364,11 @@ def __eq__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.EQ) return FilterExpression(str(self)) - def __ne__(self, other: int) -> "FilterExpression": + def __ne__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric inequality filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -380,11 +380,11 @@ def __ne__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.NE) return FilterExpression(str(self)) - def __gt__(self, other: int) -> "FilterExpression": + def __gt__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric greater than filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -396,11 +396,11 @@ def __gt__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.GT) return FilterExpression(str(self)) - def __lt__(self, other: int) -> "FilterExpression": + def __lt__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric less than filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -412,11 +412,11 @@ def __lt__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LT) return FilterExpression(str(self)) - def __ge__(self, other: int) -> "FilterExpression": + def __ge__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric greater than or equal to filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -428,11 +428,11 @@ def __ge__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.GE) return FilterExpression(str(self)) - def __le__(self, other: int) -> "FilterExpression": + def __le__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric less than or equal to filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -759,7 +759,9 @@ def _convert_to_timestamp(self, value, end_date=False): raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") - def __eq__(self, other) -> FilterExpression: + def __eq__( + self, other: Union[datetime.datetime, datetime.date, str, int, float] + ) -> FilterExpression: """ Filter for timestamps equal to the specified value. For date objects (without time), this matches the entire day. @@ -774,6 +776,7 @@ def __eq__(self, other) -> FilterExpression: # For date objects, match the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + assert isinstance(other, datetime.date) # validate for mypy start = datetime.datetime.combine(other, datetime.time.min).astimezone( datetime.timezone.utc ) @@ -786,7 +789,9 @@ def __eq__(self, other) -> FilterExpression: self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ) return FilterExpression(str(self)) - def __ne__(self, other) -> FilterExpression: + def __ne__( + self, other: Union[datetime.datetime, datetime.date, str, int, float] + ) -> FilterExpression: """ Filter for timestamps not equal to the specified value. For date objects (without time), this excludes the entire day. @@ -801,6 +806,7 @@ def __ne__(self, other) -> FilterExpression: # For date objects, exclude the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + assert isinstance(other, datetime.date) # validate for mypy start = datetime.datetime.combine(other, datetime.time.min) end = datetime.datetime.combine(other, datetime.time.max) return self.between(start, end) diff --git a/spec/MCP.md b/spec/MCP.md index 5f09e723..c160db79 100644 --- a/spec/MCP.md +++ b/spec/MCP.md @@ -13,6 +13,8 @@ metadata: This specification defines a Model Context Protocol (MCP) server for RedisVL that allows MCP clients to search and upsert data in an existing Redis index. +Search behavior is owned by server configuration. MCP clients provide query text, filtering, pagination, and field projection, but do not choose the search mode or runtime tuning parameters. + The MCP design targets indexes hosted on open-source Redis Stack, Redis Cloud, or Redis Enterprise, provided the required Search capabilities are available for the configured tool behavior. The server is designed for stdio transport first and must be runnable via: @@ -25,7 +27,7 @@ For a production-oriented usage narrative and end-to-end example, see [MCP-produ ### Goals -1. Expose RedisVL search capabilities (`vector`, `fulltext`, `hybrid`) through stable MCP tools. +1. Expose configured RedisVL search capabilities (`vector`, `fulltext`, `hybrid`) through stable MCP tools without requiring MCP clients to configure retrieval strategy. 2. Support controlled write access via an upsert tool. 3. Automatically reconstruct the index schema from an existing Redis index instead of requiring a full manual schema definition. 4. Keep the vectorizer configuration explicit and user-defined. @@ -59,7 +61,7 @@ These are hard compatibility expectations for v1. Notes: - This spec standardizes on the standalone `fastmcp` package for server implementation. It does not assume the official `mcp` package is on a 2.x line. - Client SDK examples may still use whichever client-side MCP package their ecosystem requires. -- Native hybrid support is preferred when available because it aligns with current Redis runtime capabilities, but lack of native support is not a blocker for `search_type=\"hybrid\"`. +- Native hybrid support is preferred when available because it aligns with current Redis runtime capabilities, but lack of native support is not a blocker for `indexes..search.type=\"hybrid\"` when the configured search params remain compatible with the aggregate fallback. --- @@ -148,6 +150,16 @@ indexes: dims: 1536 datatype: float32 + search: + type: hybrid + params: + text_scorer: BM25STD + stopwords: english + vector_search_method: KNN + combination_method: LINEAR + linear_text_weight: 0.3 + knn_ef_runtime: 150 + runtime: # required explicit field mapping for tool behavior text_field_name: content @@ -170,6 +182,51 @@ indexes: max_concurrency: 16 ``` +### Search Configuration (Normative) + +`indexes..search` defines the retrieval strategy for the sole bound index in v1. Tool callers must not override this configuration. + +Required fields: + +- `type`: `vector` | `fulltext` | `hybrid` +- `params`: optional object whose allowed keys depend on `type` + +Allowed `params` by `type`: + +- `vector` + - `hybrid_policy` + - `batch_size` + - `ef_runtime` + - `epsilon` + - `search_window_size` + - `use_search_history` + - `search_buffer_capacity` + - `normalize_vector_distance` +- `fulltext` + - `text_scorer` + - `stopwords` + - `text_weights` +- `hybrid` + - `text_scorer` + - `stopwords` + - `text_weights` + - `vector_search_method` + - `knn_ef_runtime` + - `range_radius` + - `range_epsilon` + - `combination_method` + - `rrf_window` + - `rrf_constant` + - `linear_text_weight` + +Normalization rules: + +1. `linear_text_weight` is the MCP config's stable meaning for linear hybrid fusion and always represents the text-side weight. +2. When building native `HybridQuery`, the server must pass `linear_text_weight` through as `linear_alpha`. +3. When building `AggregateHybridQuery`, the server must translate `linear_text_weight` to `alpha = 1 - linear_text_weight` so the config meaning does not change across implementations. +4. `linear_text_weight` is only valid when `combination_method` is `LINEAR`. +5. Hybrid configs using FT.SEARCH-only runtime params (`knn_ef_runtime`) must fail startup if the environment only supports the aggregate fallback path. + ### Schema Discovery and Override Rules 1. `server.redis_url` is required. @@ -179,19 +236,22 @@ indexes: 5. The server must reconstruct the base schema from Redis metadata, preferably via existing RedisVL inspection primitives built on `FT.INFO`. 6. `indexes..vectorizer` remains fully manual and is never inferred from Redis index metadata in v1. 7. `indexes..schema_overrides` is optional and exists only to supplement incomplete inspection data. -8. Discovered index identity is authoritative: +8. `indexes..search.type` is required and is authoritative for query construction. +9. `indexes..search.params` is optional but, when present, may only contain keys valid for the configured `search.type`. +10. Tool requests implicitly target the sole configured index binding and its configured search behavior in v1. No `index`, `search_type`, or search-tuning request parameters are exposed. +11. Tool callers may control only query text, filtering, pagination, and returned fields for `search-records`. +12. Discovered index identity is authoritative: - `indexes..redis_name` - storage type - field identity (`name`, `type`, and `path` when applicable) -9. Overrides may: +13. Overrides may: - add missing attrs for a discovered field - replace discovered attrs for a discovered field when needed for compatibility -10. Overrides must not: +14. Overrides must not: - redefine index identity - add entirely new fields that do not exist in the inspected index - change a discovered field's `name`, `type`, or `path` -11. Override conflicts must fail startup with a config error. -12. Tool requests implicitly target the sole configured index binding in v1. No `index` request parameter is exposed yet. +15. Override conflicts must fail startup with a config error. ### Env Substitution Rules @@ -210,13 +270,17 @@ Server startup must fail fast if: 4. `indexes` missing, empty, or containing more than one entry. 5. The configured binding id is blank. 6. `indexes..redis_name` missing or blank. -7. The referenced Redis index does not exist. -8. Schema inspection fails and no valid `indexes..schema_overrides` resolve the issue. -9. `indexes..runtime.text_field_name` not in the effective schema. -10. `indexes..runtime.vector_field_name` not in the effective schema or not vector type. -11. `indexes..runtime.default_embed_text_field` not in the effective schema. -12. `default_limit <= 0` or `max_limit < default_limit`. -13. `max_upsert_records <= 0`. +7. `indexes..search.type` missing or not one of `vector`, `fulltext`, `hybrid`. +8. `indexes..search.params` contains keys that are incompatible with the configured `search.type`. +9. `indexes..search.params.linear_text_weight` is present without `combination_method: LINEAR`. +10. A hybrid config relies on FT.SEARCH-only runtime params and the environment only supports the aggregate fallback path. +11. The referenced Redis index does not exist. +12. Schema inspection fails and no valid `indexes..schema_overrides` resolve the issue. +13. `indexes..runtime.text_field_name` not in the effective schema. +14. `indexes..runtime.vector_field_name` not in the effective schema or not vector type. +15. `indexes..runtime.default_embed_text_field` not in the effective schema. +16. `default_limit <= 0` or `max_limit < default_limit`. +17. `max_upsert_records <= 0`. --- @@ -234,9 +298,10 @@ On server startup: 6. Convert the inspected index metadata into an `IndexSchema`. 7. Apply any validated `indexes..schema_overrides` to produce the effective schema. 8. Instantiate `AsyncSearchIndex` from the effective schema. -9. Instantiate the configured `indexes..vectorizer`. -10. Validate vectorizer dimensions against the effective vector field dims when available. -11. Register tools (omit upsert in read-only mode). +9. Validate `indexes..search` against the effective schema and current runtime capabilities. +10. Instantiate the configured `indexes..vectorizer`. +11. Validate vectorizer dimensions against the effective vector field dims when available. +12. Register tools (omit upsert in read-only mode). If vector field attributes cannot be reconstructed from Redis metadata on the target Redis version, startup must fail with an actionable error unless `indexes..schema_overrides` provides the missing attrs. @@ -299,14 +364,13 @@ Tool executions are bounded by an async semaphore (`runtime.max_concurrency`). R ## Tool: `search-records` -Search records using vector, full-text, or hybrid query. +Search records using the configured search behavior for the bound index. ### Request Contract | Parameter | Type | Required | Default | Constraints | |----------|------|----------|---------|-------------| | `query` | str | yes | - | non-empty | -| `search_type` | enum | no | `vector` | `vector` \| `fulltext` \| `hybrid` | | `limit` | int | no | `runtime.default_limit` | `1..runtime.max_limit` | | `offset` | int | no | `0` | `>=0` | | `filter` | string \\| object | no | `null` | Raw RedisVL filter string or DSL object | @@ -335,12 +399,14 @@ Search records using vector, full-text, or hybrid query. ### Search Semantics -- `vector`: embeds `query` with configured vectorizer, builds `VectorQuery`. -- `fulltext`: builds `TextQuery`. +- `search_type` in the response is informational metadata derived from `indexes..search.type`. +- `search-records` must reject deprecated client-side search-mode or tuning inputs with `invalid_request`. +- `vector`: embeds `query` with the configured vectorizer and builds `VectorQuery` using `indexes..search.params`. +- `fulltext`: builds `TextQuery` using `indexes..search.params`. - `hybrid`: embeds `query` and selects the query implementation by runtime capability: - use native `HybridQuery` when Redis `>=8.4.0` and redis-py `>=7.1.0` are available - otherwise fall back to `AggregateHybridQuery` -- The MCP request/response contract for `hybrid` is identical across both implementation paths. +- The MCP request/response contract for `hybrid` is identical across both implementation paths because config normalization hides class-specific fusion semantics from tool callers. - In v1, `filter` is applied uniformly to the hybrid query rather than allowing separate text-side and vector-side filters. This is intentional to keep the API simple; future versions may expose finer-grained hybrid filtering controls. ### Errors @@ -421,8 +487,10 @@ For the sole configured binding in v1, the server owns these validated values: - `text_field_name` - `vector_field_name` - `default_embed_text_field` +- `search.type` +- `search.params` -Schema discovery is automatic in v1. Field mapping is not. Runtime field mappings remain explicit so the server does not guess among multiple valid text or vector fields. +Schema discovery is automatic in v1. Field mapping is not. Search construction is configuration-owned. Runtime field mappings remain explicit so the server does not guess among multiple valid text or vector fields, and MCP callers do not choose retrieval mode or tuning. --- @@ -478,7 +546,7 @@ async def main(): ) as server: agent = Agent( name="search-agent", - instructions="Search and maintain Redis-backed knowledge.", + instructions="Search and maintain Redis-backed knowledge using the server-configured retrieval strategy.", mcp_servers=[server], ) ``` @@ -494,7 +562,7 @@ from mcp import StdioServerParameters root_agent = LlmAgent( model="gemini-2.0-flash", name="redis_search_agent", - instruction="Search and maintain Redis-backed knowledge using vector search.", + instruction="Search and maintain Redis-backed knowledge using the server-configured retrieval strategy.", tools=[ McpToolset( connection_params=StdioConnectionParams( @@ -570,6 +638,8 @@ Note: Full n8n MCP client support depends on n8n's MCP implementation. Refer to - env substitution success/failure - schema inspection merge and override validation - field mapping validation + - `indexes..search` validation by type + - normalized hybrid fusion validation - `test_filters.py` - DSL parsing, invalid operators, type mismatches - `test_errors.py` @@ -582,10 +652,14 @@ Note: Full n8n MCP client support depends on n8n's MCP implementation. Refer to - missing index failure - vector field inspection gap resolved by `indexes..schema_overrides` - conflicting override failure + - hybrid config with FT.SEARCH-only params rejected when only aggregate fallback is available - `test_search_tool.py` - - vector/fulltext/hybrid success paths + - configured `vector` / `fulltext` / `hybrid` success paths + - request without `search_type` succeeds + - deprecated client-side search-mode or tuning params rejected with `invalid_request` + - response reports configured `search_type` - native hybrid path on Redis `>=8.4.0` - - aggregate hybrid fallback path on older supported runtimes + - aggregate hybrid fallback path on older supported runtimes when config is compatible - pagination and field projection - filter behavior - `test_upsert_tool.py` @@ -622,12 +696,13 @@ DoD: Deliverables: 1. `search-records` request/response contract. 2. Filter parser (JSON DSL + raw string pass-through). -3. Hybrid query selection between native and aggregate implementations. +3. Config-owned search construction and hybrid query selection between native and aggregate implementations. DoD: 1. All search modes tested. 2. Invalid filter returns `invalid_filter`. -3. `hybrid` uses native execution when available and `AggregateHybridQuery` otherwise, without changing the MCP contract. +3. Deprecated client-side search-mode and tuning inputs return `invalid_request`. +4. `hybrid` uses native execution when available and `AggregateHybridQuery` otherwise, without changing the MCP contract or the meaning of `linear_text_weight`. ### Phase 3: Upsert Tool @@ -657,12 +732,11 @@ DoD: Deliverables: 1. Config reference and examples. 2. Client setup examples. -3. Companion production example document. -4. Troubleshooting guide with common errors and fixes. +3. Troubleshooting guide with common errors and fixes. DoD: 1. Docs reflect normative contracts in this spec. -2. Companion example is aligned with the config and lifecycle contract. +2. Client-facing examples do not imply MCP callers choose retrieval mode. --- @@ -670,17 +744,20 @@ DoD: 1. Runtime mismatch for hybrid search. - Native hybrid requires newer Redis and redis-py capabilities, while older supported environments may still need the aggregate fallback path. - - Mitigation: explicitly detect runtime capability and select native `HybridQuery` or `AggregateHybridQuery` deterministically. + - Mitigation: explicitly detect runtime capability, reject incompatible hybrid configs at startup, and otherwise select native `HybridQuery` or `AggregateHybridQuery` deterministically. 2. Dependency drift across provider vectorizers. - Mitigation: dependency matrix and startup validation. -3. Ambiguous filter behavior causing agent retries. - - Mitigation: explicit raw-string pass-through semantics and deterministic DSL parser errors. +3. Search behavior drift caused by client-selected tuning. + - Mitigation: keep search mode and query construction params in config, not in the MCP request surface. 4. Hidden partial writes during failures. - Mitigation: conservative `partial_write_possible` signaling. 5. Incomplete schema reconstruction on older Redis versions. - `FT.INFO` may not return enough vector metadata on some older Redis versions to fully reconstruct vector field attrs. - Mitigation: fail fast with an actionable error and support targeted `indexes..schema_overrides` for missing attrs. -6. Security and deployment limitations (v1 scope). +6. Hybrid fusion semantics differ between `HybridQuery` and `AggregateHybridQuery`. + - Native `HybridQuery` uses text-weight semantics while `AggregateHybridQuery` exposes vector-weight semantics. + - Mitigation: normalize on `linear_text_weight` in MCP config and translate internally per execution path. +7. Security and deployment limitations (v1 scope). - This implementation is stdio-first and not production-hardened by itself. It does not include: - Authentication/authorization mechanisms. - Remote transports (SSE/HTTP) that would enable multi-tenant or networked deployments. diff --git a/tests/integration/test_mcp/test_search_tool.py b/tests/integration/test_mcp/test_search_tool.py new file mode 100644 index 00000000..a5eaf8f3 --- /dev/null +++ b/tests/integration/test_mcp/test_search_tool.py @@ -0,0 +1,307 @@ +from pathlib import Path + +import pytest +import yaml + +from redisvl.index import AsyncSearchIndex +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.server import RedisVLMCPServer +from redisvl.mcp.settings import MCPSettings +from redisvl.mcp.tools.search import search_records +from redisvl.redis.connection import is_version_gte +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema +from tests.conftest import get_redis_version_async, skip_if_redis_version_below_async + + +class FakeVectorizer: + def __init__(self, model: str, dims: int = 3, **kwargs): + self.model = model + self.dims = dims + self.kwargs = kwargs + + def embed(self, content: str = "", **kwargs): + del content, kwargs + return [0.1, 0.1, 0.5] + + +@pytest.fixture +async def searchable_index(async_client, worker_id): + schema = IndexSchema.from_dict( + { + "index": { + "name": f"mcp-search-{worker_id}", + "prefix": f"mcp-search:{worker_id}", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + index = AsyncSearchIndex(schema=schema, redis_client=async_client) + await index.create(overwrite=True, drop=True) + + def preprocess(record: dict) -> dict: + return { + **record, + "embedding": array_to_buffer(record["embedding"], "float32"), + } + + await index.load( + [ + { + "id": f"doc:{worker_id}:1", + "content": "science article about planets", + "category": "science", + "rating": 5, + "embedding": [0.1, 0.1, 0.5], + }, + { + "id": f"doc:{worker_id}:2", + "content": "medical science and health", + "category": "health", + "rating": 4, + "embedding": [0.1, 0.1, 0.4], + }, + { + "id": f"doc:{worker_id}:3", + "content": "sports update and scores", + "category": "sports", + "rating": 3, + "embedding": [-0.2, 0.1, 0.0], + }, + ], + preprocess=preprocess, + ) + + yield index + + await index.delete(drop=True) + + +@pytest.fixture +def mcp_config_path(tmp_path: Path, redis_url: str): + def factory(redis_name: str, search: dict) -> str: + config = { + "server": {"redis_url": redis_url}, + "indexes": { + "knowledge": { + "redis_name": redis_name, + "vectorizer": { + "class": "FakeVectorizer", + "model": "fake-model", + "dims": 3, + }, + "search": search, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + }, + } + }, + } + config_path = tmp_path / f"{redis_name}-{search['type']}.yaml" + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + return str(config_path) + + return factory + + +@pytest.fixture +async def started_server(monkeypatch, searchable_index, mcp_config_path): + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + + async def factory(search: dict) -> RedisVLMCPServer: + server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path(searchable_index.schema.index.name, search) + ) + ) + await server.startup() + return server + + servers = [] + + async def started(search: dict) -> RedisVLMCPServer: + server = await factory(search) + servers.append(server) + return server + + yield started + + for server in servers: + await server.shutdown() + + +@pytest.mark.asyncio +async def test_search_records_vector_success_with_pagination_and_projection( + started_server, +): + server = await started_server( + { + "type": "vector", + "params": {"normalize_vector_distance": True}, + } + ) + + response = await search_records( + server, + query="science", + limit=1, + offset=1, + return_fields=["content", "category"], + ) + + assert response["search_type"] == "vector" + assert response["offset"] == 1 + assert response["limit"] == 1 + assert len(response["results"]) == 1 + assert response["results"][0]["score_type"] == "vector_distance_normalized" + assert set(response["results"][0]["record"]) == {"content", "category"} + + +@pytest.mark.asyncio +async def test_search_records_fulltext_success(started_server): + server = await started_server( + { + "type": "fulltext", + "params": { + "text_scorer": "BM25STD.NORM", + "stopwords": None, + }, + } + ) + + response = await search_records( + server, + query="science", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "fulltext" + assert response["results"] + assert response["results"][0]["score_type"] == "text_score" + assert response["results"][0]["score"] is not None + assert "science" in response["results"][0]["record"]["content"] + + +@pytest.mark.asyncio +async def test_search_records_respects_raw_string_filter(started_server): + server = await started_server({"type": "vector"}) + + response = await search_records( + server, + query="science", + filter="@category:{science}", + return_fields=["content", "category"], + ) + + assert response["results"] + assert all( + result["record"]["category"] == "science" for result in response["results"] + ) + + +@pytest.mark.asyncio +async def test_search_records_respects_dsl_filter(started_server): + server = await started_server({"type": "vector"}) + + response = await search_records( + server, + query="science", + filter={"field": "rating", "op": "gte", "value": 4.5}, + return_fields=["content", "category", "rating"], + ) + + assert response["results"] + assert all( + float(result["record"]["rating"]) >= 4.5 for result in response["results"] + ) + + +@pytest.mark.asyncio +async def test_search_records_invalid_filter_returns_invalid_filter(started_server): + server = await started_server({"type": "vector"}) + + with pytest.raises(RedisVLMCPError) as exc_info: + await search_records( + server, + query="science", + filter={"field": "missing", "op": "eq", "value": "science"}, + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +@pytest.mark.asyncio +async def test_search_records_native_hybrid_success(started_server, async_client): + await skip_if_redis_version_below_async(async_client, "8.4.0") + server = await started_server( + { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + "stopwords": None, + }, + } + ) + + response = await search_records( + server, + query="science", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "hybrid" + assert response["results"] + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] is not None + + +@pytest.mark.asyncio +async def test_search_records_fallback_hybrid_success(started_server, async_client): + redis_version = await get_redis_version_async(async_client) + if is_version_gte(redis_version, "8.4.0"): + pytest.skip(f"Redis version {redis_version} uses native hybrid search") + + server = await started_server( + { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + "stopwords": None, + }, + } + ) + + response = await search_records( + server, + query="science", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "hybrid" + assert response["results"] + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] is not None diff --git a/tests/integration/test_mcp/test_server_startup.py b/tests/integration/test_mcp/test_server_startup.py index 8709202a..953aa6df 100644 --- a/tests/integration/test_mcp/test_server_startup.py +++ b/tests/integration/test_mcp/test_server_startup.py @@ -7,7 +7,9 @@ from redisvl.index import AsyncSearchIndex from redisvl.mcp.server import RedisVLMCPServer from redisvl.mcp.settings import MCPSettings +from redisvl.redis.connection import is_version_gte from redisvl.schema import IndexSchema +from tests.conftest import get_redis_version_async class FakeVectorizer: @@ -80,6 +82,7 @@ def factory( vector_dims: int = 3, schema_overrides: Optional[dict] = None, runtime_overrides: Optional[dict] = None, + search: Optional[dict] = None, ) -> str: runtime = { "text_field_name": "content", @@ -99,6 +102,7 @@ def factory( "model": "fake-model", "dims": vector_dims, }, + "search": search or {"type": "vector"}, "runtime": runtime, } }, @@ -136,6 +140,38 @@ async def test_server_startup_success(monkeypatch, existing_index, mcp_config_pa await server.shutdown() +@pytest.mark.asyncio +async def test_server_fails_when_hybrid_config_requires_native_runtime( + monkeypatch, existing_index, mcp_config_path, async_client +): + redis_version = await get_redis_version_async(async_client) + if is_version_gte(redis_version, "8.4.0"): + pytest.skip(f"Redis version {redis_version} supports native hybrid search") + + index = await existing_index(index_name="mcp-native-required") + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path( + redis_name=index.name, + search={ + "type": "hybrid", + "params": { + "vector_search_method": "KNN", + "knn_ef_runtime": 150, + }, + }, + ) + ) + ) + + with pytest.raises(ValueError, match="knn_ef_runtime"): + await server.startup() + + @pytest.mark.asyncio async def test_server_fails_when_configured_index_is_missing( monkeypatch, mcp_config_path, worker_id diff --git a/tests/unit/test_mcp/test_config.py b/tests/unit/test_mcp/test_config.py index 4a0520f0..bd20ed2d 100644 --- a/tests/unit/test_mcp/test_config.py +++ b/tests/unit/test_mcp/test_config.py @@ -15,6 +15,7 @@ def _valid_config() -> dict: "knowledge": { "redis_name": "docs-index", "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "search": {"type": "vector"}, "runtime": { "text_field_name": "content", "vector_field_name": "embedding", @@ -68,17 +69,19 @@ def test_load_mcp_config_env_substitution(tmp_path: Path, monkeypatch): server: redis_url: ${REDIS_URL:-redis://localhost:6379} indexes: - knowledge: - redis_name: docs-index - vectorizer: - class: FakeVectorizer - model: ${VECTOR_MODEL:-test-model} - api_config: - api_key: ${OPENAI_API_KEY} - runtime: - text_field_name: content - vector_field_name: embedding - default_embed_text_field: content + knowledge: + redis_name: docs-index + vectorizer: + class: FakeVectorizer + model: ${VECTOR_MODEL:-test-model} + api_config: + api_key: ${OPENAI_API_KEY} + search: + type: vector + runtime: + text_field_name: content + vector_field_name: embedding + default_embed_text_field: content """.strip(), encoding="utf-8", ) @@ -101,15 +104,17 @@ def test_load_mcp_config_required_env_missing(tmp_path: Path, monkeypatch): server: redis_url: redis://localhost:6379 indexes: - knowledge: - redis_name: docs-index - vectorizer: - class: FakeVectorizer - model: ${VECTOR_MODEL} - runtime: - text_field_name: content - vector_field_name: embedding - default_embed_text_field: content + knowledge: + redis_name: docs-index + vectorizer: + class: FakeVectorizer + model: ${VECTOR_MODEL} + search: + type: vector + runtime: + text_field_name: content + vector_field_name: embedding + default_embed_text_field: content """.strip(), encoding="utf-8", ) @@ -166,6 +171,7 @@ def test_mcp_config_binding_helpers(): assert config.binding_id == "knowledge" assert config.binding.redis_name == "docs-index" + assert config.binding.search.type == "vector" assert config.runtime.default_embed_text_field == "content" assert config.vectorizer.class_name == "FakeVectorizer" assert config.redis_name == "docs-index" @@ -275,3 +281,116 @@ def test_load_mcp_config_requires_exactly_one_binding(tmp_path: Path): with pytest.raises(ValueError, match="exactly one configured index binding"): load_mcp_config(str(config_path)) + + +@pytest.mark.parametrize("search_type", ["vector", "fulltext", "hybrid"]) +def test_mcp_config_accepts_search_types(search_type): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = {"type": search_type} + + loaded = MCPConfig.model_validate(config) + + assert loaded.binding.search.type == search_type + assert loaded.binding.search.params == {} + + +def test_mcp_config_requires_search_type(): + config = _valid_config() + del config["indexes"]["knowledge"]["search"]["type"] + + with pytest.raises(ValueError, match="type"): + MCPConfig.model_validate(config) + + +def test_mcp_config_rejects_invalid_search_type(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = {"type": "semantic"} + + with pytest.raises(ValueError, match="vector|fulltext|hybrid"): + MCPConfig.model_validate(config) + + +@pytest.mark.parametrize( + ("search_type", "params"), + [ + ("vector", {"text_scorer": "BM25STD"}), + ("fulltext", {"normalize_vector_distance": True}), + ("hybrid", {"normalize_vector_distance": True}), + ], +) +def test_mcp_config_rejects_invalid_search_params(search_type, params): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": search_type, + "params": params, + } + + with pytest.raises(ValueError, match="search.params"): + MCPConfig.model_validate(config) + + +def test_mcp_config_rejects_linear_text_weight_without_linear_combination(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "combination_method": "RRF", + "linear_text_weight": 0.3, + }, + } + + with pytest.raises(ValueError, match="linear_text_weight"): + MCPConfig.model_validate(config) + + +def test_mcp_config_normalizes_hybrid_linear_text_weight(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + }, + } + + loaded = MCPConfig.model_validate(config) + + assert loaded.binding.search.type == "hybrid" + assert loaded.binding.search.params["linear_text_weight"] == 0.3 + + +@pytest.mark.parametrize( + "params", + [ + {"knn_ef_runtime": 42}, + {"vector_search_method": "RANGE", "range_radius": 0.4}, + {"combination_method": "RRF", "rrf_window": 50}, + ], +) +def test_mcp_config_rejects_native_only_hybrid_runtime_params(params): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": params, + } + + loaded = MCPConfig.model_validate(config) + + with pytest.raises(ValueError, match="native hybrid search support"): + loaded.validate_search(supports_native_hybrid_search=False) + + +def test_mcp_config_allows_linear_hybrid_fallback_params(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "text_scorer": "TFIDF", + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + }, + } + + loaded = MCPConfig.model_validate(config) + + loaded.validate_search(supports_native_hybrid_search=False) diff --git a/tests/unit/test_mcp/test_errors.py b/tests/unit/test_mcp/test_errors.py index 066e3173..ddd28622 100644 --- a/tests/unit/test_mcp/test_errors.py +++ b/tests/unit/test_mcp/test_errors.py @@ -26,6 +26,18 @@ def test_import_error_maps_to_dependency_missing(): assert mapped.retryable is False +def test_filter_error_is_preserved(): + original = RedisVLMCPError( + "bad filter", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + mapped = map_exception(original) + + assert mapped is original + + def test_redis_errors_map_to_backend_unavailable(): mapped = map_exception(RedisSearchError("redis unavailable")) diff --git a/tests/unit/test_mcp/test_filters.py b/tests/unit/test_mcp/test_filters.py new file mode 100644 index 00000000..4fb43b6a --- /dev/null +++ b/tests/unit/test_mcp/test_filters.py @@ -0,0 +1,136 @@ +import pytest + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.filters import parse_filter +from redisvl.query.filter import FilterExpression +from redisvl.schema import IndexSchema + + +def _schema() -> IndexSchema: + return IndexSchema.from_dict( + { + "index": { + "name": "docs-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +def _render_filter(value): + if isinstance(value, FilterExpression): + return str(value) + return value + + +def test_parse_filter_passes_through_raw_string(): + raw = "@category:{science} @rating:[4 +inf]" + + parsed = parse_filter(raw, _schema()) + + assert parsed == raw + + +def test_parse_filter_builds_atomic_expression(): + parsed = parse_filter( + {"field": "category", "op": "eq", "value": "science"}, + _schema(), + ) + + assert isinstance(parsed, FilterExpression) + assert str(parsed) == "@category:{science}" + + +def test_parse_filter_builds_nested_logical_expression(): + parsed = parse_filter( + { + "and": [ + {"field": "category", "op": "eq", "value": "science"}, + { + "or": [ + {"field": "rating", "op": "gte", "value": 4.5}, + {"field": "content", "op": "like", "value": "quant*"}, + ] + }, + ] + }, + _schema(), + ) + + assert isinstance(parsed, FilterExpression) + assert ( + str(parsed) == "(@category:{science} (@rating:[4.5 +inf] | @content:(quant*)))" + ) + + +def test_parse_filter_builds_not_expression(): + parsed = parse_filter( + { + "not": {"field": "category", "op": "eq", "value": "science"}, + }, + _schema(), + ) + + assert _render_filter(parsed) == "(-(@category:{science}))" + + +def test_parse_filter_builds_exists_expression(): + parsed = parse_filter( + {"field": "content", "op": "exists"}, + _schema(), + ) + + assert _render_filter(parsed) == "(-ismissing(@content))" + + +def test_parse_filter_rejects_unknown_field(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "missing", "op": "eq", "value": "science"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_unknown_operator(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter( + {"field": "category", "op": "contains", "value": "science"}, _schema() + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_type_mismatch(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "rating", "op": "gte", "value": "high"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_empty_logical_array(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"and": []}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_malformed_payload(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "category", "value": "science"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py new file mode 100644 index 00000000..5bb49a0d --- /dev/null +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -0,0 +1,530 @@ +from types import SimpleNamespace +from typing import Optional + +import pytest + +from redisvl.mcp.config import MCPConfig +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.tools.search import _embed_query, register_search_tool, search_records +from redisvl.schema import IndexSchema + + +def _schema() -> IndexSchema: + return IndexSchema.from_dict( + { + "index": { + "name": "docs-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +def _config_with_search(search_type: str, params: Optional[dict] = None) -> MCPConfig: + return MCPConfig.model_validate( + { + "server": {"redis_url": "redis://localhost:6379"}, + "indexes": { + "knowledge": { + "redis_name": "docs-index", + "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "search": {"type": search_type, "params": params or {}}, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + }, + } + }, + } + ) + + +class FakeVectorizer: + async def embed(self, text: str): + return [0.1, 0.2, 0.3] + + +class FakeIndex: + def __init__(self): + self.schema = _schema() + self.query_calls = [] + + async def query(self, query): + self.query_calls.append(query) + return [] + + +class FakeServer: + def __init__( + self, + *, + search_type: str = "vector", + search_params: Optional[dict] = None, + ): + self.config = _config_with_search(search_type, search_params) + self.mcp_settings = SimpleNamespace(tool_search_description=None) + self.index = FakeIndex() + self.vectorizer = FakeVectorizer() + self.registered_tools = [] + self.native_hybrid_supported = False + + async def get_index(self): + return self.index + + async def get_vectorizer(self): + return self.vectorizer + + async def run_guarded(self, operation_name, awaitable): + return await awaitable + + async def supports_native_hybrid_search(self): + return self.native_hybrid_supported + + def tool(self, name=None, description=None, **kwargs): + def decorator(fn): + self.registered_tools.append( + { + "name": name, + "description": description, + "fn": fn, + } + ) + return fn + + return decorator + + +class FakeQuery: + def __init__(self, **kwargs): + self.kwargs = kwargs + + +@pytest.mark.asyncio +async def test_embed_query_falls_back_to_sync_embed_when_aembed_is_not_implemented(): + class FallbackVectorizer: + async def aembed(self, text: str): + raise NotImplementedError + + def embed(self, text: str): + return [0.4, 0.5, 0.6] + + embedding = await _embed_query(FallbackVectorizer(), "science") + + assert embedding == [0.4, 0.5, 0.6] + + +@pytest.mark.asyncio +async def test_search_records_rejects_blank_query(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as exc_info: + await search_records(server, query=" ") + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_search_records_rejects_invalid_limit_and_offset(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as limit_exc: + await search_records(server, query="science", limit=0) + + with pytest.raises(RedisVLMCPError) as offset_exc: + await search_records(server, query="science", offset=-1) + + assert limit_exc.value.code == MCPErrorCode.INVALID_REQUEST + assert offset_exc.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_search_records_rejects_unknown_or_vector_return_fields(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as unknown_exc: + await search_records(server, query="science", return_fields=["missing"]) + + with pytest.raises(RedisVLMCPError) as vector_exc: + await search_records(server, query="science", return_fields=["embedding"]) + + assert unknown_exc.value.code == MCPErrorCode.INVALID_REQUEST + assert vector_exc.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("result", "message"), + [ + ( + { + "id": "doc:broken", + "content": "science doc", + "category": "science", + }, + "missing expected score field", + ), + ( + { + "content": "science doc", + "category": "science", + "vector_distance": "0.93", + }, + "missing id", + ), + ], +) +async def test_search_records_treats_malformed_backend_results_as_internal_errors( + result, message +): + server = FakeServer(search_type="vector") + + async def fake_query(query): + server.index.query_calls.append(query) + return [result] + + server.index.query = fake_query + + with pytest.raises(RedisVLMCPError, match=message) as exc_info: + await search_records(server, query="science") + + assert exc_info.value.code == MCPErrorCode.INTERNAL_ERROR + assert exc_info.value.retryable is False + + +@pytest.mark.asyncio +async def test_search_records_builds_vector_query_and_normalizes_results(monkeypatch): + server = FakeServer( + search_type="vector", + search_params={"normalize_vector_distance": False, "ef_runtime": 42}, + ) + built_queries = [] + + class FakeVectorQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(kwargs) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:1", + "content": "science doc", + "category": "science", + "vector_distance": "0.93", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.VectorQuery", FakeVectorQuery) + server.index.query = fake_query + + response = await search_records(server, query="science") + + assert built_queries[0]["vector"] == [0.1, 0.2, 0.3] + assert built_queries[0]["vector_field_name"] == "embedding" + assert built_queries[0]["return_fields"] == ["content", "category", "rating"] + assert built_queries[0]["num_results"] == 2 + assert built_queries[0]["normalize_vector_distance"] is False + assert built_queries[0]["ef_runtime"] == 42 + assert response == { + "search_type": "vector", + "offset": 0, + "limit": 2, + "results": [ + { + "id": "doc:1", + "score": 0.93, + "score_type": "vector_distance", + "record": { + "content": "science doc", + "category": "science", + }, + } + ], + } + + +@pytest.mark.asyncio +async def test_search_records_builds_fulltext_query(monkeypatch): + server = FakeServer( + search_type="fulltext", + search_params={ + "text_scorer": "BM25STD.NORM", + "stopwords": None, + "text_weights": {"medical": 2.5}, + }, + ) + built_queries = [] + + class FakeTextQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(kwargs) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:2", + "content": "medical science", + "category": "health", + "__score": "1.5", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.TextQuery", FakeTextQuery) + server.index.query = fake_query + + response = await search_records( + server, + query="medical science", + limit=1, + return_fields=["content", "category"], + ) + + assert built_queries[0]["text"] == "medical science" + assert built_queries[0]["text_field_name"] == "content" + assert built_queries[0]["num_results"] == 1 + assert built_queries[0]["text_scorer"] == "BM25STD.NORM" + assert built_queries[0]["stopwords"] is None + assert built_queries[0]["text_weights"] == {"medical": 2.5} + assert response["search_type"] == "fulltext" + assert response["results"][0]["score"] == 1.5 + assert response["results"][0]["score_type"] == "text_score" + + +@pytest.mark.asyncio +async def test_search_records_builds_hybrid_query_for_native_runtime(monkeypatch): + server = FakeServer( + search_type="hybrid", + search_params={ + "text_scorer": "TFIDF", + "stopwords": None, + "text_weights": {"hybrid": 2.0}, + "vector_search_method": "KNN", + "knn_ef_runtime": 77, + "combination_method": "LINEAR", + "linear_text_weight": 0.2, + }, + ) + server.native_hybrid_supported = True + built_queries = [] + + class FakePostProcessingConfig: + def __init__(self): + self.apply_calls = [] + + def apply(self, **kwargs): + self.apply_calls.append(kwargs) + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + self.postprocessing_config = FakePostProcessingConfig() + built_queries.append(("native", kwargs, self.postprocessing_config)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:3", + "content": "hybrid doc", + "hybrid_score": "2.5", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid") + + assert built_queries[0][0] == "native" + assert built_queries[0][1]["vector"] == [0.1, 0.2, 0.3] + assert built_queries[0][1]["text_scorer"] == "TFIDF" + assert built_queries[0][1]["stopwords"] is None + assert built_queries[0][1]["text_weights"] == {"hybrid": 2.0} + assert built_queries[0][1]["vector_search_method"] == "KNN" + assert built_queries[0][1]["knn_ef_runtime"] == 77 + assert built_queries[0][1]["combination_method"] == "LINEAR" + assert built_queries[0][1]["linear_alpha"] == 0.2 + assert built_queries[0][2].apply_calls == [{"__key": "@__key"}] + assert response["search_type"] == "hybrid" + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] == 2.5 + + +@pytest.mark.asyncio +async def test_search_records_avoids_linear_defaults_for_rrf_native_hybrid_query( + monkeypatch, +): + server = FakeServer( + search_type="hybrid", + search_params={ + "combination_method": "RRF", + "rrf_window": 50, + }, + ) + server.native_hybrid_supported = True + built_queries = [] + + class FakePostProcessingConfig: + def __init__(self): + self.apply_calls = [] + + def apply(self, **kwargs): + self.apply_calls.append(kwargs) + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + self.postprocessing_config = FakePostProcessingConfig() + built_queries.append(("native", kwargs, self.postprocessing_config)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:rrf", + "content": "hybrid doc", + "hybrid_score": "1.2", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid") + + assert built_queries[0][0] == "native" + assert built_queries[0][1]["combination_method"] == "RRF" + assert built_queries[0][1]["rrf_window"] == 50 + assert "linear_alpha" not in built_queries[0][1] + assert "linear_text_weight" not in built_queries[0][1] + assert built_queries[0][2].apply_calls == [{"__key": "@__key"}] + assert response["search_type"] == "hybrid" + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] == 1.2 + + +@pytest.mark.asyncio +async def test_search_records_builds_hybrid_query_for_fallback_runtime(monkeypatch): + server = FakeServer( + search_type="hybrid", + search_params={ + "text_scorer": "TFIDF", + "stopwords": None, + "text_weights": {"hybrid": 2.0}, + "combination_method": "LINEAR", + "linear_text_weight": 0.2, + }, + ) + built_queries = [] + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("native", kwargs)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:4", + "content": "fallback hybrid", + "hybrid_score": "0.7", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid") + + assert built_queries[0][0] == "fallback" + assert built_queries[0][1]["text_scorer"] == "TFIDF" + assert built_queries[0][1]["stopwords"] is None + assert built_queries[0][1]["text_weights"] == {"hybrid": 2.0} + assert built_queries[0][1]["alpha"] == pytest.approx(0.8) + assert built_queries[0][1]["return_fields"] == [ + "__key", + "content", + "category", + "rating", + ] + assert response["search_type"] == "hybrid" + assert response["results"][0]["score"] == 0.7 + + +@pytest.mark.asyncio +async def test_search_records_rejects_native_only_hybrid_runtime_params(monkeypatch): + server = FakeServer( + search_type="hybrid", + search_params={ + "combination_method": "RRF", + "rrf_window": 50, + }, + ) + + with pytest.raises(ValueError, match="native hybrid search support"): + server.config.validate_search(supports_native_hybrid_search=False) + + +def test_register_search_tool_uses_default_and_override_descriptions(): + default_server = FakeServer() + register_search_tool(default_server) + + assert default_server.registered_tools[0]["name"] == "search-records" + assert "Search records" in default_server.registered_tools[0]["description"] + assert "query" in default_server.registered_tools[0]["fn"].__annotations__ + assert "search_type" not in default_server.registered_tools[0]["fn"].__annotations__ + + custom_server = FakeServer() + custom_server.mcp_settings.tool_search_description = "Custom search description" + register_search_tool(custom_server) + + assert ( + custom_server.registered_tools[0]["description"] == "Custom search description" + ) diff --git a/tests/unit/test_mcp/test_server_unit.py b/tests/unit/test_mcp/test_server_unit.py new file mode 100644 index 00000000..13ba23d5 --- /dev/null +++ b/tests/unit/test_mcp/test_server_unit.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + +import pytest + +from redisvl.mcp.server import RedisVLMCPServer + + +class FakeClient: + def __init__(self): + self.info_calls = 0 + + async def info(self, section: str): + self.info_calls += 1 + assert section == "server" + return {"redis_version": "8.4.0"} + + def ft(self, index_name: str): + assert index_name == "docs-index" + return SimpleNamespace(hybrid_search=object()) + + +class FakeIndex: + def __init__(self, client: FakeClient): + self.schema = SimpleNamespace(index=SimpleNamespace(name="docs-index")) + self._client = client + + async def _get_client(self): + return self._client + + +@pytest.mark.asyncio +async def test_supports_native_hybrid_search_caches_runtime_probe(monkeypatch): + client = FakeClient() + server = RedisVLMCPServer.__new__(RedisVLMCPServer) + server._index = FakeIndex(client) + server._supports_native_hybrid_search = None + + monkeypatch.setattr("redisvl.mcp.server.redis_py_version", "7.1.0") + + assert await server.supports_native_hybrid_search() is True + assert await server.supports_native_hybrid_search() is True + assert client.info_calls == 1