Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/tools/hash.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Union, Optional

import numpy as np
from redis.exceptions import RedisError

Expand All @@ -7,7 +9,7 @@

@mcp.tool()
async def hset(
name: str, key: str, value: str | int | float, expire_seconds: int = None
name: str, key: str, value: str | int | float, expire_seconds: Optional[int] = None
) -> str:
"""Set a field in a hash stored at key with an optional expiration time.

Expand Down Expand Up @@ -118,8 +120,8 @@ async def hexists(name: str, key: str) -> bool:

@mcp.tool()
async def set_vector_in_hash(
name: str, vector: list, vector_field: str = "vector"
) -> bool:
name: str, vector: List[float], vector_field: str = "vector"
) -> Union[bool, str]:
"""Store a vector as a field in a Redis hash.

Args:
Expand Down
42 changes: 31 additions & 11 deletions src/tools/json.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
import json
from typing import Union, Mapping, List, TYPE_CHECKING
from typing import Optional
from redis.exceptions import RedisError
from pydantic_core import core_schema

from src.common.connection import RedisConnectionManager
from src.common.server import mcp

# Define JsonType for type checking to match redis-py definition
# Use object as runtime type to avoid issubclass() issues with Any in Python 3.10
if TYPE_CHECKING:
JsonType = Union[
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]
else:
# Use object at runtime to avoid MCP framework issubclass() issues
JsonType = object

# Custom type that accepts any JSON value but generates a proper schema
class JsonValue:
"""Accepts any JSON-serializable value."""

@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
"""Define how Pydantic should validate this type."""
# Accept any value
return core_schema.any_schema()

@classmethod
def __get_pydantic_json_schema__(cls, _core_schema, _handler):
"""Define the JSON schema for this type."""
# Return a schema that accepts string, number, boolean, object, array, or null
return {
"anyOf": [
{"type": "string"},
{"type": "number"},
{"type": "boolean"},
{"type": "object"},
{"type": "array", "items": {"type": "string"}},
{"type": "null"},
]
}


@mcp.tool()
async def json_set(
name: str, path: str, value: JsonType, expire_seconds: int = None
name: str,
path: str,
value: JsonValue,
expire_seconds: Optional[int] = None,
) -> str:
"""Set a JSON value in Redis at a given path with an optional expiration time.

Expand Down
7 changes: 4 additions & 3 deletions src/tools/list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Union, List, Optional

from redis.exceptions import RedisError
from redis.typing import FieldT
Expand All @@ -8,7 +9,7 @@


@mcp.tool()
async def lpush(name: str, value: FieldT, expire: int = None) -> str:
async def lpush(name: str, value: FieldT, expire: Optional[int] = None) -> str:
"""Push a value onto the left of a Redis list and optionally set an expiration time."""
try:
r = RedisConnectionManager.get_connection()
Expand All @@ -21,7 +22,7 @@ async def lpush(name: str, value: FieldT, expire: int = None) -> str:


@mcp.tool()
async def rpush(name: str, value: FieldT, expire: int = None) -> str:
async def rpush(name: str, value: FieldT, expire: Optional[int] = None) -> str:
"""Push a value onto the right of a Redis list and optionally set an expiration time."""
try:
r = RedisConnectionManager.get_connection()
Expand Down Expand Up @@ -56,7 +57,7 @@ async def rpop(name: str) -> str:


@mcp.tool()
async def lrange(name: str, start: int, stop: int) -> list:
async def lrange(name: str, start: int, stop: int) -> Union[str, List[str]]:
"""Get elements from a Redis list within a specific range.

Returns:
Expand Down
10 changes: 7 additions & 3 deletions src/tools/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Union, List

from redis.exceptions import RedisError

Expand Down Expand Up @@ -100,7 +100,9 @@ async def rename(old_key: str, new_key: str) -> Dict[str, Any]:


@mcp.tool()
async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> dict:
async def scan_keys(
pattern: str = "*", count: int = 100, cursor: int = 0
) -> Union[str, Dict[str, Any]]:
"""
Scan keys in the Redis database using the SCAN command (non-blocking, production-safe).

Expand Down Expand Up @@ -152,7 +154,9 @@ async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> di


@mcp.tool()
async def scan_all_keys(pattern: str = "*", batch_size: int = 100) -> list:
async def scan_all_keys(
pattern: str = "*", batch_size: int = 100
) -> Union[str, List[str]]:
"""
Scan and return ALL keys matching a pattern using multiple SCAN iterations.

Expand Down
7 changes: 4 additions & 3 deletions src/tools/redis_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import List, Optional, Union, Dict, Any

import numpy as np
from redis.commands.search.field import VectorField
Expand Down Expand Up @@ -102,12 +103,12 @@ async def create_vector_index_hash(

@mcp.tool()
async def vector_search_hash(
query_vector: list,
query_vector: List[float],
index_name: str = "vector_index",
vector_field: str = "vector",
k: int = 5,
return_fields: list = None,
) -> list:
return_fields: Optional[List[str]] = None,
) -> Union[List[Dict[str, Any]], str]:
"""
Perform a KNN vector similarity search using Redis 8 or later version on vectors stored in hash data structures.

Expand Down
6 changes: 4 additions & 2 deletions src/tools/set.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Union, List, Optional

from redis.exceptions import RedisError

from src.common.connection import RedisConnectionManager
from src.common.server import mcp


@mcp.tool()
async def sadd(name: str, value: str, expire_seconds: int = None) -> str:
async def sadd(name: str, value: str, expire_seconds: Optional[int] = None) -> str:
"""Add a value to a Redis set with an optional expiration time.

Args:
Expand Down Expand Up @@ -54,7 +56,7 @@ async def srem(name: str, value: str) -> str:


@mcp.tool()
async def smembers(name: str) -> list:
async def smembers(name: str) -> Union[str, List[str]]:
"""Get all members of a Redis set.

Args:
Expand Down
6 changes: 5 additions & 1 deletion src/tools/sorted_set.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Optional

from redis.exceptions import RedisError

from src.common.connection import RedisConnectionManager
from src.common.server import mcp


@mcp.tool()
async def zadd(key: str, score: float, member: str, expiration: int = None) -> str:
async def zadd(
key: str, score: float, member: str, expiration: Optional[int] = None
) -> str:
"""Add a member to a Redis sorted set with an optional expiration time.

Args:
Expand Down
6 changes: 5 additions & 1 deletion src/tools/stream.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Dict, Any, Optional

from redis.exceptions import RedisError

from src.common.connection import RedisConnectionManager
from src.common.server import mcp


@mcp.tool()
async def xadd(key: str, fields: dict, expiration: int = None) -> str:
async def xadd(
key: str, fields: Dict[str, Any], expiration: Optional[int] = None
) -> str:
"""Add an entry to a Redis stream with an optional expiration time.

Args:
Expand Down
6 changes: 4 additions & 2 deletions src/tools/string.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Union
from typing import Union, Optional

from redis.exceptions import RedisError
from redis import Redis
Expand All @@ -10,7 +10,9 @@

@mcp.tool()
async def set(
key: str, value: Union[str, bytes, int, float, dict], expiration: int = None
key: str,
value: Union[str, bytes, int, float, dict],
expiration: Optional[int] = None,
) -> str:
"""Set a Redis string value with an optional expiration time.

Expand Down
Loading