Skip to content

Commit be581d3

Browse files
committed
fix: improve MCP tool schema types to avoid gemini tool exclusion
1 parent 8881af8 commit be581d3

File tree

9 files changed

+69
-29
lines changed

9 files changed

+69
-29
lines changed

src/tools/hash.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List, Union, Optional
2+
13
import numpy as np
24
from redis.exceptions import RedisError
35

@@ -7,7 +9,7 @@
79

810
@mcp.tool()
911
async def hset(
10-
name: str, key: str, value: str | int | float, expire_seconds: int = None
12+
name: str, key: str, value: str | int | float, expire_seconds: Optional[int] = None
1113
) -> str:
1214
"""Set a field in a hash stored at key with an optional expiration time.
1315
@@ -118,8 +120,8 @@ async def hexists(name: str, key: str) -> bool:
118120

119121
@mcp.tool()
120122
async def set_vector_in_hash(
121-
name: str, vector: list, vector_field: str = "vector"
122-
) -> bool:
123+
name: str, vector: List[float], vector_field: str = "vector"
124+
) -> Union[bool, str]:
123125
"""Store a vector as a field in a Redis hash.
124126
125127
Args:

src/tools/json.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,44 @@
11
import json
2-
from typing import Union, Mapping, List, TYPE_CHECKING
2+
from typing import Optional
33
from redis.exceptions import RedisError
4+
from pydantic_core import core_schema
45

56
from src.common.connection import RedisConnectionManager
67
from src.common.server import mcp
78

8-
# Define JsonType for type checking to match redis-py definition
9-
# Use object as runtime type to avoid issubclass() issues with Any in Python 3.10
10-
if TYPE_CHECKING:
11-
JsonType = Union[
12-
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
13-
]
14-
else:
15-
# Use object at runtime to avoid MCP framework issubclass() issues
16-
JsonType = object
9+
10+
# Custom type that accepts any JSON value but generates a proper schema
11+
class JsonValue:
12+
"""Accepts any JSON-serializable value."""
13+
14+
@classmethod
15+
def __get_pydantic_core_schema__(cls, _source_type, _handler):
16+
"""Define how Pydantic should validate this type."""
17+
# Accept any value
18+
return core_schema.any_schema()
19+
20+
@classmethod
21+
def __get_pydantic_json_schema__(cls, _core_schema, _handler):
22+
"""Define the JSON schema for this type."""
23+
# Return a schema that accepts string, number, boolean, object, array, or null
24+
return {
25+
"anyOf": [
26+
{"type": "string"},
27+
{"type": "number"},
28+
{"type": "boolean"},
29+
{"type": "object"},
30+
{"type": "array", "items": {"type": "string"}},
31+
{"type": "null"},
32+
]
33+
}
1734

1835

1936
@mcp.tool()
2037
async def json_set(
21-
name: str, path: str, value: JsonType, expire_seconds: int = None
38+
name: str,
39+
path: str,
40+
value: JsonValue,
41+
expire_seconds: Optional[int] = None,
2242
) -> str:
2343
"""Set a JSON value in Redis at a given path with an optional expiration time.
2444

src/tools/list.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import Union, List, Optional
23

34
from redis.exceptions import RedisError
45
from redis.typing import FieldT
@@ -8,7 +9,7 @@
89

910

1011
@mcp.tool()
11-
async def lpush(name: str, value: FieldT, expire: int = None) -> str:
12+
async def lpush(name: str, value: FieldT, expire: Optional[int] = None) -> str:
1213
"""Push a value onto the left of a Redis list and optionally set an expiration time."""
1314
try:
1415
r = RedisConnectionManager.get_connection()
@@ -21,7 +22,7 @@ async def lpush(name: str, value: FieldT, expire: int = None) -> str:
2122

2223

2324
@mcp.tool()
24-
async def rpush(name: str, value: FieldT, expire: int = None) -> str:
25+
async def rpush(name: str, value: FieldT, expire: Optional[int] = None) -> str:
2526
"""Push a value onto the right of a Redis list and optionally set an expiration time."""
2627
try:
2728
r = RedisConnectionManager.get_connection()
@@ -56,7 +57,7 @@ async def rpop(name: str) -> str:
5657

5758

5859
@mcp.tool()
59-
async def lrange(name: str, start: int, stop: int) -> list:
60+
async def lrange(name: str, start: int, stop: int) -> Union[str, List[str]]:
6061
"""Get elements from a Redis list within a specific range.
6162
6263
Returns:

src/tools/misc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Union, List
22

33
from redis.exceptions import RedisError
44

@@ -100,7 +100,9 @@ async def rename(old_key: str, new_key: str) -> Dict[str, Any]:
100100

101101

102102
@mcp.tool()
103-
async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> dict:
103+
async def scan_keys(
104+
pattern: str = "*", count: int = 100, cursor: int = 0
105+
) -> Union[str, Dict[str, Any]]:
104106
"""
105107
Scan keys in the Redis database using the SCAN command (non-blocking, production-safe).
106108
@@ -152,7 +154,9 @@ async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> di
152154

153155

154156
@mcp.tool()
155-
async def scan_all_keys(pattern: str = "*", batch_size: int = 100) -> list:
157+
async def scan_all_keys(
158+
pattern: str = "*", batch_size: int = 100
159+
) -> Union[str, List[str]]:
156160
"""
157161
Scan and return ALL keys matching a pattern using multiple SCAN iterations.
158162

src/tools/redis_query_engine.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from typing import List, Optional, Union, Dict, Any
23

34
import numpy as np
45
from redis.commands.search.field import VectorField
@@ -102,12 +103,12 @@ async def create_vector_index_hash(
102103

103104
@mcp.tool()
104105
async def vector_search_hash(
105-
query_vector: list,
106+
query_vector: List[float],
106107
index_name: str = "vector_index",
107108
vector_field: str = "vector",
108109
k: int = 5,
109-
return_fields: list = None,
110-
) -> list:
110+
return_fields: Optional[List[str]] = None,
111+
) -> Union[List[Dict[str, Any]], str]:
111112
"""
112113
Perform a KNN vector similarity search using Redis 8 or later version on vectors stored in hash data structures.
113114

src/tools/set.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
from typing import Union, List, Optional
2+
13
from redis.exceptions import RedisError
24

35
from src.common.connection import RedisConnectionManager
46
from src.common.server import mcp
57

68

79
@mcp.tool()
8-
async def sadd(name: str, value: str, expire_seconds: int = None) -> str:
10+
async def sadd(name: str, value: str, expire_seconds: Optional[int] = None) -> str:
911
"""Add a value to a Redis set with an optional expiration time.
1012
1113
Args:
@@ -54,7 +56,7 @@ async def srem(name: str, value: str) -> str:
5456

5557

5658
@mcp.tool()
57-
async def smembers(name: str) -> list:
59+
async def smembers(name: str) -> Union[str, List[str]]:
5860
"""Get all members of a Redis set.
5961
6062
Args:

src/tools/sorted_set.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from typing import Optional
2+
13
from redis.exceptions import RedisError
24

35
from src.common.connection import RedisConnectionManager
46
from src.common.server import mcp
57

68

79
@mcp.tool()
8-
async def zadd(key: str, score: float, member: str, expiration: int = None) -> str:
10+
async def zadd(
11+
key: str, score: float, member: str, expiration: Optional[int] = None
12+
) -> str:
913
"""Add a member to a Redis sorted set with an optional expiration time.
1014
1115
Args:

src/tools/stream.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from typing import Dict, Any, Optional
2+
13
from redis.exceptions import RedisError
24

35
from src.common.connection import RedisConnectionManager
46
from src.common.server import mcp
57

68

79
@mcp.tool()
8-
async def xadd(key: str, fields: dict, expiration: int = None) -> str:
10+
async def xadd(
11+
key: str, fields: Dict[str, Any], expiration: Optional[int] = None
12+
) -> str:
913
"""Add an entry to a Redis stream with an optional expiration time.
1014
1115
Args:

src/tools/string.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Union
2+
from typing import Union, Optional
33

44
from redis.exceptions import RedisError
55
from redis import Redis
@@ -10,7 +10,9 @@
1010

1111
@mcp.tool()
1212
async def set(
13-
key: str, value: Union[str, bytes, int, float, dict], expiration: int = None
13+
key: str,
14+
value: Union[str, bytes, int, float, dict],
15+
expiration: Optional[int] = None,
1416
) -> str:
1517
"""Set a Redis string value with an optional expiration time.
1618

0 commit comments

Comments
 (0)