Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions redisvl/mcp/filters.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

from typing import Any, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional, Union

from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError
from redisvl.query.filter import FilterExpression, Num, Tag, Text
from redisvl.schema import IndexSchema


def parse_filter(
value: Optional[str | dict[str, Any]], schema: IndexSchema
) -> Optional[str | FilterExpression]:
value: Optional[Union[str, Dict[str, Any]]], schema: IndexSchema
) -> Optional[Union[str, FilterExpression]]:
"""Parse an MCP filter value into a RedisVL filter representation."""
if value is None:
return None
Expand All @@ -24,7 +22,7 @@ def parse_filter(
return _parse_expression(value, schema)


def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpression:
def _parse_expression(value: Dict[str, Any], schema: IndexSchema) -> FilterExpression:
logical_keys = [key for key in ("and", "or", "not") if key in value]
if logical_keys:
if len(logical_keys) != 1 or len(value) != 1:
Expand Down Expand Up @@ -53,7 +51,7 @@ def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpre
retryable=False,
)

expressions: list[FilterExpression] = []
expressions: List[FilterExpression] = []
for child in children:
if not isinstance(child, dict):
raise RedisVLMCPError(
Expand Down Expand Up @@ -205,7 +203,7 @@ def _require_string(value: Any, field_name: str, op: str) -> str:
return value


def _require_string_list(value: Any, field_name: str, op: str) -> list[str]:
def _require_string_list(value: Any, field_name: str, op: str) -> List[str]:
if not isinstance(value, list) or not value:
raise RedisVLMCPError(
f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array",
Expand All @@ -216,7 +214,7 @@ def _require_string_list(value: Any, field_name: str, op: str) -> list[str]:
return strings


def _require_number(value: Any, field_name: str, op: str) -> int | float:
def _require_number(value: Any, field_name: str, op: str) -> Union[int, float]:
if isinstance(value, bool) or not isinstance(value, (int, float)):
raise RedisVLMCPError(
f"filter value for field '{field_name}' and operator '{op}' must be numeric",
Expand All @@ -226,7 +224,9 @@ def _require_number(value: Any, field_name: str, op: str) -> int | float:
return value


def _require_number_list(value: Any, field_name: str, op: str) -> list[int | float]:
def _require_number_list(
value: Any, field_name: str, op: str
) -> List[Union[int, float]]:
if not isinstance(value, list) or not value:
raise RedisVLMCPError(
f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array",
Expand Down
6 changes: 4 additions & 2 deletions redisvl/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from redisvl.index import AsyncSearchIndex
from redisvl.mcp.config import MCPConfig, load_mcp_config
from redisvl.mcp.settings import MCPSettings
from redisvl.mcp.tools.search import register_search_tool
from redisvl.mcp.tools.upsert import register_upsert_tool
from redisvl.redis.connection import RedisConnectionFactory, is_version_gte
from redisvl.schema import IndexSchema

Expand Down Expand Up @@ -181,9 +183,9 @@ def _register_tools(self) -> None:
if self._tools_registered or not hasattr(self, "tool"):
return

from redisvl.mcp.tools.search import register_search_tool

register_search_tool(self)
if not self.mcp_settings.read_only:
register_upsert_tool(self)
self._tools_registered = True

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion redisvl/mcp/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from redisvl.mcp.tools.search import search_records
from redisvl.mcp.tools.upsert import upsert_records

__all__ = ["search_records"]
__all__ = ["search_records", "upsert_records"]
272 changes: 272 additions & 0 deletions redisvl/mcp/tools/upsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
import asyncio
import inspect
from typing import Any, Dict, List, Optional

from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError, map_exception
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import StorageType
from redisvl.schema.validation import validate_object

# Fallback tool description used by `register_upsert_tool` when
# `server.mcp_settings.tool_upsert_description` is unset or empty.
DEFAULT_UPSERT_DESCRIPTION = "Upsert records in the configured Redis index."


def _validate_request(
    *,
    server: Any,
    records: List[Dict[str, Any]],
    id_field: Optional[str],
    skip_embedding_if_present: Optional[bool],
) -> bool:
    """Check an upsert request against the public contract.

    Args:
        server: Object exposing ``config.runtime`` with the
            ``max_upsert_records`` and ``skip_embedding_if_present`` settings.
        records: Caller-supplied records; must be a non-empty list of dicts no
            longer than ``max_upsert_records``.
        id_field: Optional name of the key every record must contain.
        skip_embedding_if_present: Optional per-request override of the
            runtime ``skip_embedding_if_present`` default.

    Returns:
        The effective skip-embedding flag: the per-request override when one
        was given, otherwise the runtime default.

    Raises:
        RedisVLMCPError: With code ``INVALID_REQUEST`` (non-retryable) when any
            of the checks above fails.
    """
    limits = server.config.runtime

    def _reject(message):
        # Every contract violation is reported the same way; only the text varies.
        return RedisVLMCPError(
            message,
            code=MCPErrorCode.INVALID_REQUEST,
            retryable=False,
        )

    if not (isinstance(records, list) and records):
        raise _reject("records must be a non-empty list")
    if len(records) > limits.max_upsert_records:
        raise _reject(
            "records length must be less than or equal to "
            f"{limits.max_upsert_records}"
        )
    if id_field is not None:
        if not isinstance(id_field, str) or not id_field:
            raise _reject("id_field must be a non-empty string when provided")

    # Start from the configured default and let an explicit request value win.
    resolved_skip = limits.skip_embedding_if_present
    if skip_embedding_if_present is not None:
        if isinstance(skip_embedding_if_present, bool):
            resolved_skip = skip_embedding_if_present
        else:
            raise _reject("skip_embedding_if_present must be a boolean when provided")

    for entry in records:
        if not isinstance(entry, dict):
            raise _reject("records must contain only objects")
        if id_field is None or id_field in entry:
            continue
        raise _reject(
            "id_field '{id_field}' must exist in every record".format(
                id_field=id_field
            )
        )

    return resolved_skip


def _record_needs_embedding(
    record: Dict[str, Any],
    *,
    vector_field_name: str,
    skip_embedding_if_present: bool,
) -> bool:
    """Report whether ``record`` still needs a server-generated vector.

    A record needs embedding when skipping is disabled, or when its
    ``vector_field_name`` entry is absent or ``None``.
    """
    if not skip_embedding_if_present:
        # Skipping disabled: always embed, even if a vector was supplied.
        return True
    # A missing key and an explicit None are treated the same way.
    return record.get(vector_field_name) is None


def _validate_embed_sources(
    records: List[Dict[str, Any]],
    *,
    embed_text_field: str,
    vector_field_name: str,
    skip_embedding_if_present: bool,
) -> List[str]:
    """Gather the text to embed for every record that needs a vector.

    Records for which ``_record_needs_embedding`` is false are skipped. Each
    remaining record must carry a non-blank string under ``embed_text_field``.

    Returns:
        The source strings, in record order, one per record needing embedding.

    Raises:
        RedisVLMCPError: With code ``INVALID_REQUEST`` (non-retryable) when a
            record needing embedding has no usable text.
    """
    # Lazy filter: records are checked one at a time, in order.
    pending = (
        item
        for item in records
        if _record_needs_embedding(
            item,
            vector_field_name=vector_field_name,
            skip_embedding_if_present=skip_embedding_if_present,
        )
    )

    sources: List[str] = []
    for item in pending:
        text = item.get(embed_text_field)
        if isinstance(text, str) and text.strip():
            # The original (unstripped) text is what gets embedded.
            sources.append(text)
            continue
        raise RedisVLMCPError(
            "records requiring embedding must include a non-empty "
            "'{field}' field".format(field=embed_text_field),
            code=MCPErrorCode.INVALID_REQUEST,
            retryable=False,
        )

    return sources


async def _embed_one(vectorizer: Any, content: str) -> List[float]:
    """Embed a single string, preferring the vectorizer's async API.

    Tries ``vectorizer.aembed`` first. If that attribute is not callable or
    raises ``NotImplementedError``, falls back to ``vectorizer.embed``, which
    is awaited directly when it is a coroutine function and otherwise run via
    ``asyncio.to_thread``.

    Raises:
        AttributeError: If the vectorizer has no ``embed`` attribute to fall
            back to.
    """
    async_embed = getattr(vectorizer, "aembed", None)
    if callable(async_embed):
        try:
            return await async_embed(content)
        except NotImplementedError:
            # Async variant is declared but not implemented; use embed() below.
            pass

    fallback_embed = getattr(vectorizer, "embed", None)
    if fallback_embed is None:
        raise AttributeError("Configured vectorizer does not support embed()")
    if not inspect.iscoroutinefunction(fallback_embed):
        # Plain callable: hand it to a worker thread instead of calling it inline.
        return await asyncio.to_thread(fallback_embed, content)
    return await fallback_embed(content)


async def _embed_many(vectorizer: Any, contents: List[str]) -> List[List[float]]:
    """Embed several strings, preferring batch APIs over per-item calls.

    Resolution order:
      1. ``vectorizer.aembed_many`` (skipped if it raises
         ``NotImplementedError``),
      2. ``vectorizer.embed_many`` (awaited if it is a coroutine function,
         otherwise run via ``asyncio.to_thread``),
      3. one ``_embed_one`` call per string, in order.

    Returns:
        An empty list when ``contents`` is empty; otherwise whatever the
        selected embedding path produced.
    """
    if not contents:
        return []

    batch_async = getattr(vectorizer, "aembed_many", None)
    if callable(batch_async):
        try:
            return await batch_async(contents)
        except NotImplementedError:
            # Declared but unimplemented; continue with the other strategies.
            pass

    batch_sync = getattr(vectorizer, "embed_many", None)
    if not callable(batch_sync):
        # No batch API at all: embed sequentially, one string at a time.
        return [await _embed_one(vectorizer, text) for text in contents]
    if inspect.iscoroutinefunction(batch_sync):
        return await batch_sync(contents)
    return await asyncio.to_thread(batch_sync, contents)


def _vector_dtype(server: Any, index: Any) -> str:
    """Return the vector field's datatype as a lowercase string.

    The field is looked up through ``server.config.get_vector_field`` using
    ``index.schema``. If the datatype object has a ``value`` attribute, that
    value is used; otherwise the datatype itself is stringified.
    """
    vector_field = server.config.get_vector_field(index.schema)
    raw_datatype = vector_field.attrs.datatype
    # Unwrap objects that carry their payload in `.value`; pass others through.
    unwrapped = getattr(raw_datatype, "value", raw_datatype)
    return str(unwrapped).lower()


def _prepare_record_for_storage(
    record: Dict[str, Any],
    *,
    server: Any,
    index: Any,
) -> Dict[str, Any]:
    """Build the storage-ready form of ``record`` and validate it.

    Works on a shallow copy, so the caller's dict is left untouched. When the
    index uses ``StorageType.HASH`` and the vector field holds a ``list``, the
    list is passed through ``array_to_buffer`` with the configured vector
    dtype. The result is then checked with ``validate_object`` against the
    index schema.

    Returns:
        The copied (and possibly converted) record.
    """
    storable = dict(record)
    vector_key = server.config.runtime.vector_field_name
    candidate_vector = storable.get(vector_key)

    # Only hash storage gets the list -> buffer conversion; anything that is
    # not a list is left exactly as supplied.
    if index.schema.index.storage_type == StorageType.HASH and isinstance(
        candidate_vector, list
    ):
        storable[vector_key] = array_to_buffer(
            candidate_vector,
            _vector_dtype(server, index),
        )

    validate_object(index.schema, storable)
    return storable


async def upsert_records(
    server: Any,
    *,
    records: List[Dict[str, Any]],
    id_field: Optional[str] = None,
    skip_embedding_if_present: Optional[bool] = None,
) -> Dict[str, Any]:
    """Run the `upsert-records` operation against the server's index.

    Steps: obtain the index, validate the request, embed the records that need
    a vector, prepare every record for storage, then load them through
    ``server.run_guarded``.

    Args:
        server: Object providing ``get_index``, ``get_vectorizer``,
            ``run_guarded`` and ``config.runtime``.
        records: Records to upsert; the caller's dicts are not modified.
        id_field: Optional record key forwarded to ``index.load``.
        skip_embedding_if_present: Optional override of the runtime default
            controlling whether records that already have a vector are
            re-embedded.

    Returns:
        A dict with ``status``, ``keys_upserted`` and ``keys``.

    Raises:
        RedisVLMCPError: Propagated unchanged when raised directly; any other
            exception is converted with ``map_exception``. Failures during the
            load step carry ``metadata["partial_write_possible"] = True``.
    """
    try:
        index = await server.get_index()
        skip_existing = _validate_request(
            server=server,
            records=records,
            id_field=id_field,
            skip_embedding_if_present=skip_embedding_if_present,
        )
        settings = server.config.runtime
        vector_key = settings.vector_field_name

        # Work on shallow copies: embeddings are written into these dicts and
        # the caller's inputs must stay as they were passed in.
        working = [original.copy() for original in records]
        texts = _validate_embed_sources(
            working,
            embed_text_field=settings.default_embed_text_field,
            vector_field_name=vector_key,
            skip_embedding_if_present=skip_existing,
        )

        if texts:
            vectorizer = await server.get_vectorizer()
            vectors = await _embed_many(vectorizer, texts)
            # `vectors` is compact: it has entries only for records that needed
            # embedding, so it is walked with its own cursor.
            cursor = 0
            for item in working:
                if not _record_needs_embedding(
                    item,
                    vector_field_name=vector_key,
                    skip_embedding_if_present=skip_existing,
                ):
                    continue
                item[vector_key] = vectors[cursor]
                cursor += 1

        storable = []
        for item in working:
            storable.append(
                _prepare_record_for_storage(item, server=server, index=index)
            )

        try:
            keys = await server.run_guarded(
                "upsert-records",
                index.load(storable, id_field=id_field),
            )
        except Exception as exc:
            # Flag any failure at the load step so the caller knows some
            # records may already have been written.
            failure = map_exception(exc)
            failure.metadata["partial_write_possible"] = True
            raise failure

        return {
            "status": "success",
            "keys_upserted": len(keys),
            "keys": keys,
        }
    except RedisVLMCPError:
        # Already in the MCP error shape; let it through untouched.
        raise
    except Exception as exc:
        raise map_exception(exc)


def register_upsert_tool(server: Any) -> None:
    """Attach the `upsert-records` tool to ``server``.

    The tool description comes from
    ``server.mcp_settings.tool_upsert_description`` and falls back to
    ``DEFAULT_UPSERT_DESCRIPTION`` when that setting is empty or unset.
    Registration goes through ``server.tool(name=..., description=...)``.
    """
    configured = server.mcp_settings.tool_upsert_description
    tool_description = configured if configured else DEFAULT_UPSERT_DESCRIPTION

    # The wrapper closes over `server` so the registered callable exposes only
    # the request parameters.
    async def upsert_records_tool(
        records: List[Dict[str, Any]],
        id_field: Optional[str] = None,
        skip_embedding_if_present: Optional[bool] = None,
    ):
        """FastMCP wrapper for the `upsert-records` tool."""
        return await upsert_records(
            server,
            records=records,
            id_field=id_field,
            skip_embedding_if_present=skip_embedding_if_present,
        )

    register = server.tool(name="upsert-records", description=tool_description)
    register(upsert_records_tool)
Loading
Loading