From be581d3bef7afb192226fa94d27f1b84cab3a541 Mon Sep 17 00:00:00 2001 From: Vasil Chomakov Date: Thu, 2 Oct 2025 13:32:15 +0300 Subject: [PATCH] fix: improve MCP tool schema types to avoid gemini tool exclusion --- src/tools/hash.py | 8 ++++--- src/tools/json.py | 42 ++++++++++++++++++++++++--------- src/tools/list.py | 7 +++--- src/tools/misc.py | 10 +++++--- src/tools/redis_query_engine.py | 7 +++--- src/tools/set.py | 6 +++-- src/tools/sorted_set.py | 6 ++++- src/tools/stream.py | 6 ++++- src/tools/string.py | 6 +++-- 9 files changed, 69 insertions(+), 29 deletions(-) diff --git a/src/tools/hash.py b/src/tools/hash.py index 60212ac..0e70a86 100644 --- a/src/tools/hash.py +++ b/src/tools/hash.py @@ -1,3 +1,5 @@ +from typing import List, Union, Optional + import numpy as np from redis.exceptions import RedisError @@ -7,7 +9,7 @@ @mcp.tool() async def hset( - name: str, key: str, value: str | int | float, expire_seconds: int = None + name: str, key: str, value: str | int | float, expire_seconds: Optional[int] = None ) -> str: """Set a field in a hash stored at key with an optional expiration time. @@ -118,8 +120,8 @@ async def hexists(name: str, key: str) -> bool: @mcp.tool() async def set_vector_in_hash( - name: str, vector: list, vector_field: str = "vector" -) -> bool: + name: str, vector: List[float], vector_field: str = "vector" +) -> Union[bool, str]: """Store a vector as a field in a Redis hash. Args: diff --git a/src/tools/json.py b/src/tools/json.py index e0d0624..5cfb4fd 100644 --- a/src/tools/json.py +++ b/src/tools/json.py @@ -1,24 +1,44 @@ import json -from typing import Union, Mapping, List, TYPE_CHECKING +from typing import Optional from redis.exceptions import RedisError +from pydantic_core import core_schema from src.common.connection import RedisConnectionManager from src.common.server import mcp -# Define JsonType for type checking to match redis-py definition -# Use object as runtime type to avoid issubclass() issues with Any in Python 3.10 -if TYPE_CHECKING: - JsonType = Union[ - str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"] - ] -else: - # Use object at runtime to avoid MCP framework issubclass() issues - JsonType = object + +# Custom type that accepts any JSON value but generates a proper schema +class JsonValue: + """Accepts any JSON-serializable value.""" + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type, _handler): + """Define how Pydantic should validate this type.""" + # Accept any value + return core_schema.any_schema() + + @classmethod + def __get_pydantic_json_schema__(cls, _core_schema, _handler): + """Define the JSON schema for this type.""" + # Return a schema that accepts string, number, boolean, object, array, or null + return { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "object"}, + {"type": "array", "items": {"type": "string"}}, + {"type": "null"}, + ] + } @mcp.tool() async def json_set( - name: str, path: str, value: JsonType, expire_seconds: int = None + name: str, + path: str, + value: JsonValue, + expire_seconds: Optional[int] = None, ) -> str: """Set a JSON value in Redis at a given path with an optional expiration time. diff --git a/src/tools/list.py b/src/tools/list.py index 9b69291..929f24e 100644 --- a/src/tools/list.py +++ b/src/tools/list.py @@ -1,4 +1,5 @@ import json +from typing import Union, List, Optional from redis.exceptions import RedisError from redis.typing import FieldT @@ -8,7 +9,7 @@ @mcp.tool() -async def lpush(name: str, value: FieldT, expire: int = None) -> str: +async def lpush(name: str, value: FieldT, expire: Optional[int] = None) -> str: """Push a value onto the left of a Redis list and optionally set an expiration time.""" try: r = RedisConnectionManager.get_connection() @@ -21,7 +22,7 @@ async def lpush(name: str, value: FieldT, expire: int = None) -> str: @mcp.tool() -async def rpush(name: str, value: FieldT, expire: int = None) -> str: +async def rpush(name: str, value: FieldT, expire: Optional[int] = None) -> str: """Push a value onto the right of a Redis list and optionally set an expiration time.""" try: r = RedisConnectionManager.get_connection() @@ -56,7 +57,7 @@ async def rpop(name: str) -> str: @mcp.tool() -async def lrange(name: str, start: int, stop: int) -> list: +async def lrange(name: str, start: int, stop: int) -> Union[str, List[str]]: """Get elements from a Redis list within a specific range. Returns: diff --git a/src/tools/misc.py b/src/tools/misc.py index 72877c9..d79ba17 100644 --- a/src/tools/misc.py +++ b/src/tools/misc.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Union, List from redis.exceptions import RedisError @@ -100,7 +100,9 @@ async def rename(old_key: str, new_key: str) -> Dict[str, Any]: @mcp.tool() -async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> dict: +async def scan_keys( + pattern: str = "*", count: int = 100, cursor: int = 0 +) -> Union[str, Dict[str, Any]]: """ Scan keys in the Redis database using the SCAN command (non-blocking, production-safe). @@ -152,7 +154,9 @@ async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> di @mcp.tool() -async def scan_all_keys(pattern: str = "*", batch_size: int = 100) -> list: +async def scan_all_keys( + pattern: str = "*", batch_size: int = 100 +) -> Union[str, List[str]]: """ Scan and return ALL keys matching a pattern using multiple SCAN iterations. diff --git a/src/tools/redis_query_engine.py b/src/tools/redis_query_engine.py index 6e04420..4ffa81c 100644 --- a/src/tools/redis_query_engine.py +++ b/src/tools/redis_query_engine.py @@ -1,4 +1,5 @@ import json +from typing import List, Optional, Union, Dict, Any import numpy as np from redis.commands.search.field import VectorField @@ -102,12 +103,12 @@ async def create_vector_index_hash( @mcp.tool() async def vector_search_hash( - query_vector: list, + query_vector: List[float], index_name: str = "vector_index", vector_field: str = "vector", k: int = 5, - return_fields: list = None, -) -> list: + return_fields: Optional[List[str]] = None, +) -> Union[List[Dict[str, Any]], str]: """ Perform a KNN vector similarity search using Redis 8 or later version on vectors stored in hash data structures. diff --git a/src/tools/set.py b/src/tools/set.py index ae8f2b2..bb2de0b 100644 --- a/src/tools/set.py +++ b/src/tools/set.py @@ -1,3 +1,5 @@ +from typing import Union, List, Optional + from redis.exceptions import RedisError from src.common.connection import RedisConnectionManager @@ -5,7 +7,7 @@ @mcp.tool() -async def sadd(name: str, value: str, expire_seconds: int = None) -> str: +async def sadd(name: str, value: str, expire_seconds: Optional[int] = None) -> str: """Add a value to a Redis set with an optional expiration time. Args: @@ -54,7 +56,7 @@ async def srem(name: str, value: str) -> str: @mcp.tool() -async def smembers(name: str) -> list: +async def smembers(name: str) -> Union[str, List[str]]: """Get all members of a Redis set. Args: diff --git a/src/tools/sorted_set.py b/src/tools/sorted_set.py index 8bd80c3..c633a74 100644 --- a/src/tools/sorted_set.py +++ b/src/tools/sorted_set.py @@ -1,3 +1,5 @@ +from typing import Optional + from redis.exceptions import RedisError from src.common.connection import RedisConnectionManager @@ -5,7 +7,9 @@ @mcp.tool() -async def zadd(key: str, score: float, member: str, expiration: int = None) -> str: +async def zadd( + key: str, score: float, member: str, expiration: Optional[int] = None +) -> str: """Add a member to a Redis sorted set with an optional expiration time. Args: diff --git a/src/tools/stream.py b/src/tools/stream.py index f1df3ba..ebca6b7 100644 --- a/src/tools/stream.py +++ b/src/tools/stream.py @@ -1,3 +1,5 @@ +from typing import Dict, Any, Optional + from redis.exceptions import RedisError from src.common.connection import RedisConnectionManager @@ -5,7 +7,9 @@ @mcp.tool() -async def xadd(key: str, fields: dict, expiration: int = None) -> str: +async def xadd( + key: str, fields: Dict[str, Any], expiration: Optional[int] = None +) -> str: """Add an entry to a Redis stream with an optional expiration time. Args: diff --git a/src/tools/string.py b/src/tools/string.py index c3e191e..f4f76ef 100644 --- a/src/tools/string.py +++ b/src/tools/string.py @@ -1,5 +1,5 @@ import json -from typing import Union +from typing import Union, Optional from redis.exceptions import RedisError from redis import Redis @@ -10,7 +10,9 @@ @mcp.tool() async def set( - key: str, value: Union[str, bytes, int, float, dict], expiration: int = None + key: str, + value: Union[str, bytes, int, float, dict], + expiration: Optional[int] = None, ) -> str: """Set a Redis string value with an optional expiration time.