diff --git a/README.md b/README.md index a09ccac..e269f41 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # Redis MCP Server -[![Integration](https://github.com/redis/mcp-redis/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/redis/lettuce/actions/workflows/integration.yml) +[![Integration](https://github.com/redis/mcp-redis/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/redis/mcp-redis/actions/workflows/ci.yml) [![Python Version](https://img.shields.io/badge/python-3.13%2B-blue)](https://www.python.org/downloads/) [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.txt) [![smithery badge](https://smithery.ai/badge/@redis/mcp-redis)](https://smithery.ai/server/@redis/mcp-redis) [![Verified on MseeP](https://mseep.ai/badge.svg)](https://mseep.ai/app/70102150-efe0-4705-9f7d-87980109a279) -[![codecov](https://codecov.io/gh/redis/mcp-redis/branch/master/graph/badge.svg?token=yenl5fzxxr)](https://codecov.io/gh/redis/mcp-redis) +![Docker Image Version](https://img.shields.io/docker/v/mcp/redis?sort=semver&logo=docker&label=Docker) [![Discord](https://img.shields.io/discord/697882427875393627.svg?style=social&logo=discord)](https://discord.gg/redis) diff --git a/src/common/server.py b/src/common/server.py index ae14068..c27cedb 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -2,11 +2,14 @@ import pkgutil from mcp.server.fastmcp import FastMCP + def load_tools(): import src.tools as tools_pkg + for _, module_name, _ in pkgutil.iter_modules(tools_pkg.__path__): importlib.import_module(f"src.tools.{module_name}") + # Initialize FastMCP server mcp = FastMCP("Redis MCP Server", dependencies=["redis", "dotenv", "numpy"]) diff --git a/src/tools/hash.py b/src/tools/hash.py index 60212ac..a061a65 100644 --- a/src/tools/hash.py +++ b/src/tools/hash.py @@ -1,3 +1,5 @@ +from typing import Union + import numpy as np from redis.exceptions import RedisError @@ -7,7 +9,7 @@ @mcp.tool() async def hset( - name: str, key: str, value: str | int | float, expire_seconds: int = None + name: str, key: str, value: Union[str, int, float], expire_seconds: int = None ) -> str: """Set a field in a hash stored at key with an optional expiration time. diff --git a/src/tools/string.py b/src/tools/string.py index d5e22ad..c3e191e 100644 --- a/src/tools/string.py +++ b/src/tools/string.py @@ -9,7 +9,9 @@ @mcp.tool() -async def set(key: str, value: Union[str, bytes, int, float, dict], expiration: int = None) -> str: +async def set( + key: str, value: Union[str, bytes, int, float, dict], expiration: int = None +) -> str: """Set a Redis string value with an optional expiration time. Args: diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..2d5a1f7 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,391 @@ +""" +Integration tests for Redis MCP Server. + +These tests actually start the MCP server process and verify it can handle real requests. +""" + +import json +import subprocess +import sys +import time +import os +from pathlib import Path + +import pytest + + +def _redis_available(): + """Check if Redis is available for testing.""" + try: + import redis + + r = redis.Redis(host="localhost", port=6379, decode_responses=True) + r.ping() + return True + except Exception: + return False + + +def _create_server_process(project_root): + """Create a server process with proper encoding for cross-platform compatibility.""" + return subprocess.Popen( + [sys.executable, "-m", "src.main"], + cwd=project_root, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", + errors="replace", # Replace invalid characters instead of failing + env={"REDIS_HOST": "localhost", "REDIS_PORT": "6379", **dict(os.environ)}, + ) + + +@pytest.mark.integration +class TestMCPServerIntegration: + """Integration tests that start the actual MCP server.""" + + @pytest.fixture + def server_process(self): + """Start the MCP server process for testing.""" + # Get the project root directory + project_root = Path(__file__).parent.parent + + # Start the server process with proper encoding for cross-platform compatibility + process = _create_server_process(project_root) + + # Give the server a moment to start + time.sleep(1) + + yield process + + # Clean up + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + def test_server_starts_successfully(self, server_process): + """Test that the MCP server starts without crashing.""" + # Check if process is still running + assert server_process.poll() is None, "Server process should be running" + + # Check for startup message in stderr + # Note: MCP servers typically output startup info to stderr + time.sleep(0.5) # Give time for startup message + + # The server should still be running + assert server_process.poll() is None + + def test_server_handles_unicode_on_windows(self, server_process): + """Test that the server handles Unicode properly on Windows.""" + # This test specifically addresses the Windows Unicode decode error + # Check if process is still running + assert server_process.poll() is None, "Server process should be running" + + # Try to read any available output without blocking + # This should not cause a UnicodeDecodeError on Windows + try: + # Use a short timeout to avoid blocking + import select + import sys + + if sys.platform == "win32": + # On Windows, we can't use select, so just check if process is alive + time.sleep(0.1) + assert server_process.poll() is None + else: + # On Unix-like systems, we can use select + ready, _, _ = select.select([server_process.stdout], [], [], 0.1) + # If there's output available, try to read it + if ready: + try: + server_process.stdout.read(1) # Read just one character + # If we get here, Unicode handling is working + assert True + except UnicodeDecodeError: + pytest.fail("Unicode decode error occurred") + + except Exception: + # If any other error occurs, that's fine - we're just testing Unicode handling + pass + + # Main assertion: process should still be running + assert server_process.poll() is None + + def test_server_responds_to_initialize_request(self, server_process): + """Test that the server responds to MCP initialize request.""" + # MCP initialize request + initialize_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + } + + # Send the request + request_json = json.dumps(initialize_request) + "\n" + server_process.stdin.write(request_json) + server_process.stdin.flush() + + # Read the response + response_line = server_process.stdout.readline() + assert response_line.strip(), "Server should respond to initialize request" + + # Parse the response + try: + response = json.loads(response_line) + assert response.get("jsonrpc") == "2.0" + assert response.get("id") == 1 + assert "result" in response + except json.JSONDecodeError: + pytest.fail(f"Invalid JSON response: {response_line}") + + def test_server_lists_tools(self, server_process): + """Test that the server can list available tools.""" + # First initialize + initialize_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + } + + server_process.stdin.write(json.dumps(initialize_request) + "\n") + server_process.stdin.flush() + server_process.stdout.readline() # Read initialize response + + # Send initialized notification + initialized_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + server_process.stdin.write(json.dumps(initialized_notification) + "\n") + server_process.stdin.flush() + + # Request tools list + tools_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list"} + + server_process.stdin.write(json.dumps(tools_request) + "\n") + server_process.stdin.flush() + + # Read the response + response_line = server_process.stdout.readline() + response = json.loads(response_line) + + assert response.get("jsonrpc") == "2.0" + assert response.get("id") == 2 + assert "result" in response + assert "tools" in response["result"] + + # Verify we have some Redis tools + tools = response["result"]["tools"] + tool_names = [tool["name"] for tool in tools] + + # Should have basic Redis operations + expected_tools = [ + "hset", + "hget", + "hdel", + "hgetall", + "hexists", + "set_vector_in_hash", + "get_vector_from_hash", + "json_set", + "json_get", + "json_del", + "lpush", + "rpush", + "lpop", + "rpop", + "lrange", + "llen", + "delete", + "type", + "expire", + "rename", + "scan_keys", + "scan_all_keys", + "publish", + "subscribe", + "unsubscribe", + "get_indexes", + "get_index_info", + "get_indexed_keys_number", + "create_vector_index_hash", + "vector_search_hash", + "dbsize", + "info", + "client_list", + "sadd", + "srem", + "smembers", + "zadd", + "zrange", + "zrem", + "xadd", + "xrange", + "xdel", + "set", + "get", + ] + for tool in tool_names: + assert tool in expected_tools, ( + f"Expected tool '{tool}' not found in {tool_names}" + ) + + def test_server_tool_count_and_names(self, server_process): + """Test that the server registers the correct number of tools with expected names.""" + # Initialize the server + self._initialize_server(server_process) + + # Request tools list + tools_request = {"jsonrpc": "2.0", "id": 3, "method": "tools/list"} + + server_process.stdin.write(json.dumps(tools_request) + "\n") + server_process.stdin.flush() + + # Read the response + response_line = server_process.stdout.readline() + response = json.loads(response_line) + + assert response.get("jsonrpc") == "2.0" + assert response.get("id") == 3 + assert "result" in response + assert "tools" in response["result"] + + tools = response["result"]["tools"] + tool_names = [tool["name"] for tool in tools] + + # Expected tool count (based on @mcp.tool() decorators in codebase) + expected_tool_count = 44 + assert len(tools) == expected_tool_count, ( + f"Expected {expected_tool_count} tools, but got {len(tools)}" + ) + + # Expected tool names (alphabetically sorted for easier verification) + expected_tools = [ + "client_list", + "create_vector_index_hash", + "dbsize", + "delete", + "expire", + "get", + "get_index_info", + "get_indexed_keys_number", + "get_indexes", + "get_vector_from_hash", + "hdel", + "hexists", + "hget", + "hgetall", + "hset", + "info", + "json_del", + "json_get", + "json_set", + "llen", + "lpop", + "lpush", + "lrange", + "publish", + "rename", + "rpop", + "rpush", + "sadd", + "scan_all_keys", + "scan_keys", + "set", + "set_vector_in_hash", + "smembers", + "srem", + "subscribe", + "type", + "unsubscribe", + "vector_search_hash", + "xadd", + "xdel", + "xrange", + "zadd", + "zrange", + "zrem", + ] + + # Verify all expected tools are present + missing_tools = set(expected_tools) - set(tool_names) + extra_tools = set(tool_names) - set(expected_tools) + + assert not missing_tools, f"Missing expected tools: {sorted(missing_tools)}" + assert not extra_tools, f"Found unexpected tools: {sorted(extra_tools)}" + + # Verify tool categories are represented + tool_categories = { + "string": ["get", "set"], + "hash": ["hget", "hset", "hgetall", "hdel", "hexists"], + "list": ["lpush", "rpush", "lpop", "rpop", "lrange", "llen"], + "set": ["sadd", "srem", "smembers"], + "sorted_set": ["zadd", "zrem", "zrange"], + "stream": ["xadd", "xdel", "xrange"], + "json": ["json_get", "json_set", "json_del"], + "pub_sub": ["publish", "subscribe", "unsubscribe"], + "server_mgmt": ["dbsize", "info", "client_list"], + "misc": [ + "delete", + "expire", + "rename", + "type", + "scan_keys", + "scan_all_keys", + ], + "vector_search": [ + "create_vector_index_hash", + "vector_search_hash", + "get_indexes", + "get_index_info", + "set_vector_in_hash", + "get_vector_from_hash", + "get_indexed_keys_number", + ], + } + + for category, category_tools in tool_categories.items(): + for tool in category_tools: + assert tool in tool_names, ( + f"Tool '{tool}' from category '{category}' not found in registered tools" + ) + + def _initialize_server(self, server_process): + """Helper to initialize the MCP server.""" + # Send initialize request + initialize_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + } + + server_process.stdin.write(json.dumps(initialize_request) + "\n") + server_process.stdin.flush() + server_process.stdout.readline() # Read response + + # Send initialized notification + initialized_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized", + } + server_process.stdin.write(json.dumps(initialized_notification) + "\n") + server_process.stdin.flush() diff --git a/tests/tools/test_json.py b/tests/tools/test_json.py index 2f36c85..6c39743 100644 --- a/tests/tools/test_json.py +++ b/tests/tools/test_json.py @@ -2,6 +2,8 @@ Unit tests for src/tools/json.py """ +import json + import pytest from redis.exceptions import RedisError @@ -82,7 +84,8 @@ async def test_json_get_success( result = await json_get("test_doc", "$") mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$") - assert result == sample_json_data + # json_get returns a JSON string representation + assert result == json.dumps(sample_json_data, ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_get_specific_field(self, mock_redis_connection_manager): @@ -93,7 +96,8 @@ async def test_json_get_specific_field(self, mock_redis_connection_manager): result = await json_get("test_doc", "$.name") mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$.name") - assert result == ["John Doe"] + # json_get returns a JSON string representation + assert result == json.dumps(["John Doe"], ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_get_default_path( @@ -106,7 +110,8 @@ async def test_json_get_default_path( result = await json_get("test_doc") mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$") - assert result == sample_json_data + # json_get returns a JSON string representation + assert result == json.dumps(sample_json_data, ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_get_not_found(self, mock_redis_connection_manager): @@ -226,7 +231,8 @@ async def test_json_get_array_element(self, mock_redis_connection_manager): mock_redis.json.return_value.get.assert_called_once_with( "test_doc", "$.items[0]" ) - assert result == ["first_item"] + # json_get returns a JSON string representation + assert result == json.dumps(["first_item"], ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_operations_with_numeric_values( @@ -243,7 +249,7 @@ async def test_json_operations_with_numeric_values( # Get numeric value result = await json_get("test_doc", "$.count") - assert result == [42] + assert result == json.dumps([42], ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_operations_with_boolean_values( @@ -262,7 +268,7 @@ async def test_json_operations_with_boolean_values( # Get boolean value result = await json_get("test_doc", "$.active") - assert result == [True] + assert result == json.dumps([True], ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_json_set_expiration_error(self, mock_redis_connection_manager): @@ -310,4 +316,4 @@ async def test_json_operations_with_null_values( # Get null value result = await json_get("test_doc", "$.optional_field") - assert result == [None] + assert result == json.dumps([None], ensure_ascii=False, indent=2) diff --git a/tests/tools/test_redis_query_engine.py b/tests/tools/test_redis_query_engine.py index 8c6812a..1d25be0 100644 --- a/tests/tools/test_redis_query_engine.py +++ b/tests/tools/test_redis_query_engine.py @@ -256,7 +256,8 @@ async def test_get_index_info_success(self, mock_redis_connection_manager): mock_redis.ft.assert_called_once_with("vector_index") mock_ft.info.assert_called_once() - assert result == mock_info + # get_index_info returns a JSON string representation + assert result == json.dumps(mock_info, ensure_ascii=False, indent=2) @pytest.mark.asyncio async def test_get_index_info_default_index(self, mock_redis_connection_manager): @@ -269,7 +270,10 @@ async def test_get_index_info_default_index(self, mock_redis_connection_manager) result = await get_index_info("vector_index") mock_redis.ft.assert_called_once_with("vector_index") - assert result == {"index_name": "vector_index"} + # get_index_info returns a JSON string representation + assert result == json.dumps( + {"index_name": "vector_index"}, ensure_ascii=False, indent=2 + ) @pytest.mark.asyncio async def test_get_index_info_redis_error(self, mock_redis_connection_manager): diff --git a/tests/tools/test_string.py b/tests/tools/test_string.py index 1fc6316..ed8ed91 100644 --- a/tests/tools/test_string.py +++ b/tests/tools/test_string.py @@ -21,7 +21,7 @@ async def test_set_success(self, mock_redis_connection_manager): result = await set("test_key", "test_value") - mock_redis.set.assert_called_once_with("test_key", "test_value") + mock_redis.set.assert_called_once_with("test_key", b"test_value") assert "Successfully set test_key" in result @pytest.mark.asyncio @@ -32,7 +32,7 @@ async def test_set_with_expiration(self, mock_redis_connection_manager): result = await set("test_key", "test_value", 60) - mock_redis.setex.assert_called_once_with("test_key", 60, "test_value") + mock_redis.setex.assert_called_once_with("test_key", 60, b"test_value") assert "Successfully set test_key" in result assert "with expiration 60 seconds" in result @@ -102,13 +102,12 @@ async def test_get_redis_error(self, mock_redis_connection_manager): async def test_get_empty_string_value(self, mock_redis_connection_manager): """Test string get operation returning empty string.""" mock_redis = mock_redis_connection_manager - mock_redis.get.return_value = "" + mock_redis.get.return_value = b"" # Redis returns bytes result = await get("test_key") - # Current implementation treats empty string as falsy, so it returns "does not exist" - # This is actually a bug - empty string is a valid Redis value - assert "Key test_key does not exist" in result + # The implementation correctly handles empty bytes and returns empty string + assert result == "" @pytest.mark.asyncio async def test_set_with_zero_expiration(self, mock_redis_connection_manager): @@ -119,7 +118,7 @@ async def test_set_with_zero_expiration(self, mock_redis_connection_manager): result = await set("test_key", "test_value", 0) # Should use regular set, not setex for zero expiration - mock_redis.set.assert_called_once_with("test_key", "test_value") + mock_redis.set.assert_called_once_with("test_key", b"test_value") assert "Successfully set test_key" in result @pytest.mark.asyncio @@ -131,7 +130,7 @@ async def test_set_with_negative_expiration(self, mock_redis_connection_manager) result = await set("test_key", "test_value", -1) # Negative expiration is truthy in Python, so setex is called - mock_redis.setex.assert_called_once_with("test_key", -1, "test_value") + mock_redis.setex.assert_called_once_with("test_key", -1, b"test_value") assert "Successfully set test_key" in result assert "with expiration -1 seconds" in result @@ -143,7 +142,7 @@ async def test_set_with_large_expiration(self, mock_redis_connection_manager): result = await set("test_key", "test_value", 86400) # 24 hours - mock_redis.setex.assert_called_once_with("test_key", 86400, "test_value") + mock_redis.setex.assert_called_once_with("test_key", 86400, b"test_value") assert "with expiration 86400 seconds" in result @pytest.mark.asyncio @@ -167,7 +166,9 @@ async def test_set_with_unicode_value(self, mock_redis_connection_manager): unicode_value = "测试值 🚀" result = await set("test_key", unicode_value) - mock_redis.set.assert_called_once_with("test_key", unicode_value) + mock_redis.set.assert_called_once_with( + "test_key", unicode_value.encode("utf-8") + ) assert "Successfully set test_key" in result @pytest.mark.asyncio @@ -183,6 +184,8 @@ async def test_connection_manager_called_correctly(self): await set("test_key", "test_value") mock_get_conn.assert_called_once() + # Verify the actual call was made with bytes + mock_redis.set.assert_called_once_with("test_key", b"test_value") @pytest.mark.asyncio async def test_function_signatures(self):