diff --git a/.github/scripts/increment_version.py b/.github/scripts/increment_version.py index 1d7257c..0b5d41d 100644 --- a/.github/scripts/increment_version.py +++ b/.github/scripts/increment_version.py @@ -97,17 +97,13 @@ def extract_version(pyproject_content: str): return VersionLine(old_line=version_line, version_str=version_part) -def increment_version_at_pyproject( - pyproject_path: str, inc_type: str, with_beta: bool -) -> str: +def increment_version_at_pyproject(pyproject_path: str, inc_type: str, with_beta: bool) -> str: with open(pyproject_path, "rt") as f: setup_content = f.read() version = extract_version(setup_content) version.increment(inc_type, with_beta) - setup_content = setup_content.replace( - version.old_line, version.version_line_with_mark() - ) + setup_content = setup_content.replace(version.old_line, version.version_line_with_mark()) with open(pyproject_path, "w") as f: f.write(setup_content) @@ -143,9 +139,7 @@ def main(): help="increment version type: patch or minor", choices=["minor", "patch"], ) - parser.add_argument( - "--beta", choices=["true", "false"], help="is beta version" - ) + parser.add_argument("--beta", choices=["true", "false"], help="is beta version") parser.add_argument( "--changelog-path", default=DEFAULT_CHANGELOG_PATH, @@ -158,13 +152,11 @@ def main(): is_beta = args.beta == "true" - new_version = increment_version_at_pyproject( - args.pyproject_path, args.inc_type, is_beta - ) + new_version = increment_version_at_pyproject(args.pyproject_path, args.inc_type, is_beta) add_changelog_version(args.changelog_path, new_version) set_version_in_version_file(DEFAULT_YDB_VERSION_FILE, new_version) print(new_version) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/.github/workflows/run-linters.yml b/.github/workflows/run-linters.yml new file mode 100644 index 0000000..2e9c751 --- /dev/null +++ b/.github/workflows/run-linters.yml @@ -0,0 +1,26 @@ +name: Run Linters + +on: + push: + branches: + 
- '**' + pull_request: + branches: + - '**' + workflow_dispatch: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.13' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + - name: Run linters + run: make lint \ No newline at end of file diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 99e3cce..960550a 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -22,6 +22,5 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements-dev.txt - name: Run tests run: make test \ No newline at end of file diff --git a/Makefile b/Makefile index b7c49ad..1489cdd 100644 --- a/Makefile +++ b/Makefile @@ -43,15 +43,12 @@ run-server: # Run lint checks lint: dev - flake8 ydb_mcp tests + ruff check ydb_mcp tests mypy ydb_mcp - black --check ydb_mcp tests - isort --check-only --profile black ydb_mcp tests # Format code format: dev - black ydb_mcp tests - isort --profile black ydb_mcp tests + ruff format ydb_mcp tests # Install package install: @@ -59,5 +56,4 @@ install: # Install development dependencies dev: - pip install -e ".[dev]" - pip install -r requirements-dev.txt \ No newline at end of file + pip install -e ".[dev]" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7237b20..4286c03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,24 +27,35 @@ dependencies = [ [project.optional-dependencies] dev = [ - "pytest>=7.0.0", - "black>=22.0.0", - "isort>=5.0.0", - "mypy>=0.9.0", + "pytest>=7.3.1", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.1.0", + "pytest-assume>=2.4.3", + "mypy>=1.3.0", + "ruff>=0.11.0", + "docker>=7.0.0", ] [project.scripts] ydb-mcp = "ydb_mcp.__main__:main" -[tool.black] -line-length = 100 -target-version = ["py38"] +[tool.ruff] +line-length = 
121 +target-version = "py310" -[tool.isort] -profile = "black" -line_length = 100 +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "I", # isort + # TODO: extend with more rules +] [tool.mypy] -python_version = "3.8" +python_version = "3.10" warn_return_any = true -warn_unused_configs = true \ No newline at end of file +warn_unused_configs = true + +[[tool.mypy.overrides]] +module = "ydb.*" +ignore_missing_imports = true \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 789942e..3ca2d8e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,9 +3,6 @@ pytest>=7.3.1 pytest-asyncio>=0.21.0 pytest-cov>=4.1.0 pytest-assume>=2.4.3 -black>=23.3.0 -isort>=5.12.0 mypy>=1.3.0 -flake8>=6.0.0 -httpx>=0.24.0 +ruff>=0.11.0 docker>=7.0.0 # For YDB Docker container management in tests \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 6a9bd0b..cc5d73b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Pytest configuration for testing YDB MCP server.""" import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest diff --git a/tests/docker_utils.py b/tests/docker_utils.py index 8178458..34f28e4 100644 --- a/tests/docker_utils.py +++ b/tests/docker_utils.py @@ -1,6 +1,5 @@ import logging import os -import platform import socket import time diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 1c92515..afac3a8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -15,7 +15,6 @@ from urllib.parse import urlparse import pytest -import ydb from tests.docker_utils import start_ydb_container, stop_container, wait_for_port from ydb_mcp.server import AUTH_MODE_ANONYMOUS, YDBMCPServer @@ -279,7 +278,7 @@ async def session_mcp_server(ydb_server): @pytest.fixture(scope="function") -async def mcp_server(session_mcp_server): +async def 
mcp_server(session_mcp_server): # noqa: F811 """Provide a clean MCP server connection for each test by restarting the connection.""" if session_mcp_server is None: pytest.fail("Could not get a valid MCP server instance") diff --git a/tests/integration/test_authentication_integration.py b/tests/integration/test_authentication_integration.py index 755475c..d4f1354 100644 --- a/tests/integration/test_authentication_integration.py +++ b/tests/integration/test_authentication_integration.py @@ -14,9 +14,7 @@ from ydb_mcp.server import AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD # Suppress the utcfromtimestamp deprecation warning from the YDB library -warnings.filterwarnings( - "ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning -) +warnings.filterwarnings("ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning) # Table name used for tests - using timestamp to avoid conflicts TEST_TABLE = f"mcp_integration_test_{int(time.time())}" @@ -58,12 +56,7 @@ async def test_login_password_authentication(mcp_server): # Verify we can execute a query result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") # Parse the JSON from the 'text' field if present - if ( - isinstance(result, list) - and len(result) > 0 - and isinstance(result[0], dict) - and "text" in result[0] - ): + if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict) and "text" in result[0]: parsed = json.loads(result[0]["text"]) else: parsed = result @@ -80,12 +73,7 @@ async def test_login_password_authentication(mcp_server): # Query should fail with auth error result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") # Parse the JSON from the 'text' field if present - if ( - isinstance(result, list) - and len(result) > 0 - and isinstance(result[0], dict) - and "text" in result[0] - ): + if isinstance(result, list) and len(result) > 0 and isinstance(result[0], dict) and "text" in 
result[0]: parsed = json.loads(result[0]["text"]) else: parsed = result @@ -111,9 +99,9 @@ async def test_login_password_authentication(mcp_server): # Allow empty error message as valid pass else: - assert any( - keyword in error_msg for keyword in all_keywords - ), f"Unexpected error message: {parsed.get('error')}" + assert any(keyword in error_msg for keyword in all_keywords), ( + f"Unexpected error message: {parsed.get('error')}" + ) finally: # Switch back to anonymous auth to clean up (fixture will handle final state reset) diff --git a/tests/integration/test_mcp_server_integration.py b/tests/integration/test_mcp_server_integration.py index a318fc2..e38549b 100644 --- a/tests/integration/test_mcp_server_integration.py +++ b/tests/integration/test_mcp_server_integration.py @@ -8,7 +8,6 @@ import datetime import json import logging -import os import time import warnings @@ -18,9 +17,7 @@ from tests.integration.conftest import call_mcp_tool # Suppress the utcfromtimestamp deprecation warning from the YDB library -warnings.filterwarnings( - "ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning -) +warnings.filterwarnings("ignore", message="datetime.datetime.utcfromtimestamp.*", category=DeprecationWarning) # Table name used for tests - using timestamp to avoid conflicts TEST_TABLE = f"mcp_integration_test_{int(time.time())}" @@ -37,23 +34,19 @@ async def test_simple_query(mcp_server): """Test a basic YDB query.""" result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1+1 as result") - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" - assert ( - len(parsed["result_sets"]) 
== 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, f"No columns in result: {first_result}" assert "rows" in first_result, f"No rows in result: {first_result}" assert len(first_result["rows"]) > 0, f"Empty result set: {first_result}" - assert ( - first_result["columns"][0] == "result" - ), f"Unexpected column name: {first_result['columns'][0]}" + assert first_result["columns"][0] == "result", f"Unexpected column name: {first_result['columns'][0]}" assert first_result["rows"][0][0] == 2, f"Unexpected result value: {first_result['rows'][0][0]}" @@ -75,11 +68,9 @@ async def test_create_table_and_query(mcp_server): ); """, ) - assert ( - isinstance(create_result, list) - and len(create_result) > 0 - and "text" in create_result[0] - ), f"Result should be a list of dicts with 'text': {create_result}" + assert isinstance(create_result, list) and len(create_result) > 0 and "text" in create_result[0], ( + f"Result should be a list of dicts with 'text': {create_result}" + ) parsed = json.loads(create_result[0]["text"]) assert "error" not in parsed, f"Error creating table: {parsed}" @@ -92,28 +83,22 @@ async def test_create_table_and_query(mcp_server): VALUES (1, 'Test 1'), (2, 'Test 2'), (3, 'Test 3'); """, ) - assert ( - isinstance(insert_result, list) - and len(insert_result) > 0 - and "text" in insert_result[0] - ), f"Result should be a list of dicts with 'text': {insert_result}" + assert isinstance(insert_result, list) and len(insert_result) > 0 and "text" in insert_result[0], ( + f"Result should be a list of dicts with 'text': {insert_result}" + ) parsed = json.loads(insert_result[0]["text"]) assert "error" not in parsed, f"Error inserting data: {parsed}" # Query data - query_result = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"SELECT * FROM {test_table_name} ORDER BY 
id;" - ) + query_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"SELECT * FROM {test_table_name} ORDER BY id;") - assert ( - isinstance(query_result, list) and len(query_result) > 0 and "text" in query_result[0] - ), f"Result should be a list of dicts with 'text': {query_result}" + assert isinstance(query_result, list) and len(query_result) > 0 and "text" in query_result[0], ( + f"Result should be a list of dicts with 'text': {query_result}" + ) parsed = json.loads(query_result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, f"No columns in result: {first_result}" @@ -121,36 +106,26 @@ async def test_create_table_and_query(mcp_server): assert len(first_result["rows"]) == 3, f"Expected 3 rows, got {len(first_result['rows'])}" # Check if 'id' and 'name' columns are present - assert ( - "id" in first_result["columns"] - ), f"Column 'id' not found in {first_result['columns']}" - assert ( - "name" in first_result["columns"] - ), f"Column 'name' not found in {first_result['columns']}" + assert "id" in first_result["columns"], f"Column 'id' not found in {first_result['columns']}" + assert "name" in first_result["columns"], f"Column 'name' not found in {first_result['columns']}" # Get column indexes id_idx = first_result["columns"].index("id") name_idx = first_result["columns"].index("name") # Verify values - assert ( - first_result["rows"][0][id_idx] == 1 - ), f"Expected id=1, got {first_result['rows'][0][id_idx]}" + assert first_result["rows"][0][id_idx] == 1, f"Expected id=1, got {first_result['rows'][0][id_idx]}" # YDB may return strings as bytes, so handle both cases name_value = first_result["rows"][0][name_idx] if 
isinstance(name_value, bytes): - assert ( - name_value.decode("utf-8") == "Test 1" - ), f"Expected name='Test 1', got {name_value.decode('utf-8')}" + assert name_value.decode("utf-8") == "Test 1", f"Expected name='Test 1', got {name_value.decode('utf-8')}" else: assert name_value == "Test 1", f"Expected name='Test 1', got {name_value}" finally: # Cleanup - drop the table after test - cleanup_result = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};" - ) + cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};") logger.debug(f"Table cleanup result: {cleanup_result}") @@ -170,15 +145,13 @@ async def test_parameterized_query(mcp_server): ), ) - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, f"No columns in result: {first_result}" @@ -186,28 +159,22 @@ async def test_parameterized_query(mcp_server): assert len(first_result["rows"]) > 0, f"Empty result set: {first_result}" # Check column names - assert ( - "answer" in first_result["columns"] - ), f"Expected 'answer' column in result: {first_result['columns']}" - assert ( - "greeting" in first_result["columns"] - ), f"Expected 'greeting' column in result: {first_result['columns']}" + assert "answer" in first_result["columns"], f"Expected 'answer' column in result: {first_result['columns']}" + assert "greeting" in 
first_result["columns"], f"Expected 'greeting' column in result: {first_result['columns']}" # Check values answer_idx = first_result["columns"].index("answer") greeting_idx = first_result["columns"].index("greeting") - assert ( - first_result["rows"][0][answer_idx] == -42 - ), f"Expected answer=-42, got {first_result['rows'][0][answer_idx]}" + assert first_result["rows"][0][answer_idx] == -42, f"Expected answer=-42, got {first_result['rows'][0][answer_idx]}" # YDB may return strings either as bytes or as strings depending on context greeting_value = first_result["rows"][0][greeting_idx] if isinstance(greeting_value, bytes): # If bytes, decode to string - assert ( - greeting_value.decode("utf-8") == "hello" - ), f"Expected greeting to decode to 'hello', got {greeting_value.decode('utf-8')}" + assert greeting_value.decode("utf-8") == "hello", ( + f"Expected greeting to decode to 'hello', got {greeting_value.decode('utf-8')}" + ) else: # If already string assert greeting_value == "hello", f"Expected greeting to be 'hello', got {greeting_value}" @@ -225,9 +192,9 @@ async def test_complex_query(mcp_server): ) # Check for result sets - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" @@ -268,9 +235,9 @@ async def test_multiple_resultsets_with_tables(mcp_server): CREATE TABLE {test_table2} (id Uint64, value Double, PRIMARY KEY (id)); """, ) - assert ( - isinstance(setup_result, list) and len(setup_result) > 0 and "text" in setup_result[0] - ), f"Result should be a list of dicts with 'text': {setup_result}" + assert isinstance(setup_result, list) and len(setup_result) > 0 and "text" in setup_result[0], ( + f"Result should be 
a list of dicts with 'text': {setup_result}" + ) parsed = json.loads(setup_result[0]["text"]) assert "error" not in parsed, f"Error creating tables: {parsed}" @@ -282,11 +249,9 @@ async def test_multiple_resultsets_with_tables(mcp_server): UPSERT INTO {test_table1} (id, name) VALUES (1, 'First'), (2, 'Second'), (3, 'Third'); """, ) - assert ( - isinstance(insert_result, list) - and len(insert_result) > 0 - and "text" in insert_result[0] - ), f"Result should be a list of dicts with 'text': {insert_result}" + assert isinstance(insert_result, list) and len(insert_result) > 0 and "text" in insert_result[0], ( + f"Result should be a list of dicts with 'text': {insert_result}" + ) parsed = json.loads(insert_result[0]["text"]) assert "error" not in parsed, f"Error inserting data into table1: {parsed}" @@ -297,11 +262,9 @@ async def test_multiple_resultsets_with_tables(mcp_server): UPSERT INTO {test_table2} (id, value) VALUES (1, 10.5), (2, 20.75), (3, 30.25); """, ) - assert ( - isinstance(insert_result2, list) - and len(insert_result2) > 0 - and "text" in insert_result2[0] - ), f"Result should be a list of dicts with 'text': {insert_result2}" + assert isinstance(insert_result2, list) and len(insert_result2) > 0 and "text" in insert_result2[0], ( + f"Result should be a list of dicts with 'text': {insert_result2}" + ) parsed = json.loads(insert_result2[0]["text"]) assert "error" not in parsed, f"Error inserting data into table2: {parsed}" @@ -316,31 +279,25 @@ async def test_multiple_resultsets_with_tables(mcp_server): ) # Verify we have all result sets - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" # Check first table 
results first_result = parsed["result_sets"][0] assert len(first_result["rows"]) == 3, "Expected 3 rows in first table" - assert ( - "id" in first_result["columns"] - ), f"Expected 'id' column in first table, got {first_result['columns']}" - assert ( - "name" in first_result["columns"] - ), f"Expected 'name' column in first table, got {first_result['columns']}" + assert "id" in first_result["columns"], f"Expected 'id' column in first table, got {first_result['columns']}" + assert "name" in first_result["columns"], f"Expected 'name' column in first table, got {first_result['columns']}" # Check second table results second_result = parsed["result_sets"][1] assert len(second_result["rows"]) == 3, "Expected 3 rows in second table" - assert ( - "id" in second_result["columns"] - ), f"Expected 'id' column in second table, got {second_result['columns']}" - assert ( - "value" in second_result["columns"] - ), f"Expected 'value' column in second table, got {second_result['columns']}" + assert "id" in second_result["columns"], f"Expected 'id' column in second table, got {second_result['columns']}" + assert "value" in second_result["columns"], ( + f"Expected 'value' column in second table, got {second_result['columns']}" + ) # Now test a join query - should return a single result set join_result = await call_mcp_tool( @@ -355,36 +312,30 @@ async def test_multiple_resultsets_with_tables(mcp_server): ) # Validate join results - assert ( - isinstance(join_result, list) and len(join_result) > 0 and "text" in join_result[0] - ), f"Result should be a list of dicts with 'text': {join_result}" + assert isinstance(join_result, list) and len(join_result) > 0 and "text" in join_result[0], ( + f"Result should be a list of dicts with 'text': {join_result}" + ) parsed = json.loads(join_result[0]["text"]) assert "result_sets" in parsed, "Join query should return result_sets" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set for join, got {len(parsed['result_sets'])}" + 
assert len(parsed["result_sets"]) == 1, f"Expected 1 result set for join, got {len(parsed['result_sets'])}" first_join_result = parsed["result_sets"][0] assert "columns" in first_join_result, "Join query should return columns" assert "rows" in first_join_result, "Join query should return rows" - assert ( - len(first_join_result["rows"]) == 3 - ), f"Expected 3 rows in join result, got {len(first_join_result['rows'])}" - assert ( - len(first_join_result["columns"]) == 3 - ), f"Expected 3 columns in join result, got {len(first_join_result['columns'])}" + assert len(first_join_result["rows"]) == 3, ( + f"Expected 3 rows in join result, got {len(first_join_result['rows'])}" + ) + assert len(first_join_result["columns"]) == 3, ( + f"Expected 3 columns in join result, got {len(first_join_result['columns'])}" + ) finally: # Cleanup - drop the tables after test try: - cleanup_result = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"DROP TABLE {test_table1};" - ) + cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table1};") logger.debug(f"Table1 cleanup result: {cleanup_result}") - cleanup_result2 = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"DROP TABLE {test_table2};" - ) + cleanup_result2 = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table2};") logger.debug(f"Table2 cleanup result: {cleanup_result2}") except Exception as e: logger.warning(f"Failed to clean up test tables: {e}") @@ -395,14 +346,12 @@ async def test_single_resultset_format(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 42 as answer") # Single result set should have result_sets key with one item - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = 
json.loads(result[0]["text"]) assert "result_sets" in parsed, "Single result should include result_sets key" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, "Should have columns in result set" @@ -418,31 +367,23 @@ async def test_data_types(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_query", sql="SELECT 1 AS value, 'test' AS text") # Basic result checks - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, f"No columns in result: {first_result}" assert "rows" in first_result, f"No rows in result: {first_result}" - assert ( - len(first_result["columns"]) == 2 - ), f"Expected 2 columns, got {len(first_result['columns'])}" + assert len(first_result["columns"]) == 2, f"Expected 2 columns, got {len(first_result['columns'])}" assert len(first_result["rows"]) == 1, f"Expected 1 row, got {len(first_result['rows'])}" # Verify column names - assert ( - "value" in first_result["columns"] - ), f"Expected column 'value', got {first_result['columns']}" - assert ( - "text" in first_result["columns"] - ), f"Expected column 'text', got {first_result['columns']}" + assert "value" in first_result["columns"], 
f"Expected column 'value', got {first_result['columns']}" + assert "text" in first_result["columns"], f"Expected column 'text', got {first_result['columns']}" # Verify values row = first_result["rows"][0] @@ -565,15 +506,13 @@ async def test_all_data_types(mcp_server): ) # Basic result checks - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) assert "result_sets" in parsed, f"No result_sets in parsed result: {parsed}" - assert ( - len(parsed["result_sets"]) == 1 - ), f"Expected 1 result set, got {len(parsed['result_sets'])}" + assert len(parsed["result_sets"]) == 1, f"Expected 1 result set, got {len(parsed['result_sets'])}" first_result = parsed["result_sets"][0] assert "columns" in first_result, f"No columns in result: {first_result}" @@ -593,103 +532,67 @@ def get_value(column_name): # Test each data type # Boolean values - assert ( - get_value("bool_true") is True - ), f"Expected bool_true to be True, got {get_value('bool_true')}" - assert ( - get_value("bool_false") is False - ), f"Expected bool_false to be False, got {get_value('bool_false')}" + assert get_value("bool_true") is True, f"Expected bool_true to be True, got {get_value('bool_true')}" + assert get_value("bool_false") is False, f"Expected bool_false to be False, got {get_value('bool_false')}" # Integer types (signed) - assert ( - get_value("int8_min") == -128 - ), f"Expected int8_min to be -128, got {get_value('int8_min')}" + assert get_value("int8_min") == -128, f"Expected int8_min to be -128, got {get_value('int8_min')}" assert get_value("int8_max") == 127, f"Expected int8_max to be 127, got {get_value('int8_max')}" - assert ( - get_value("int16_min") == -32768 - ), f"Expected int16_min to be -32768, got 
{get_value('int16_min')}" - assert ( - get_value("int16_max") == 32767 - ), f"Expected int16_max to be 32767, got {get_value('int16_max')}" - assert ( - get_value("int32_min") == -2147483648 - ), f"Expected int32_min to be -2147483648, got {get_value('int32_min')}" - assert ( - get_value("int32_max") == 2147483647 - ), f"Expected int32_max to be 2147483647, got {get_value('int32_max')}" - assert ( - get_value("int64_min") == -9223372036854775808 - ), f"Expected int64_min to be -9223372036854775808, got {get_value('int64_min')}" - assert ( - get_value("int64_max") == 9223372036854775807 - ), f"Expected int64_max to be 9223372036854775807, got {get_value('int64_max')}" + assert get_value("int16_min") == -32768, f"Expected int16_min to be -32768, got {get_value('int16_min')}" + assert get_value("int16_max") == 32767, f"Expected int16_max to be 32767, got {get_value('int16_max')}" + assert get_value("int32_min") == -2147483648, f"Expected int32_min to be -2147483648, got {get_value('int32_min')}" + assert get_value("int32_max") == 2147483647, f"Expected int32_max to be 2147483647, got {get_value('int32_max')}" + assert get_value("int64_min") == -9223372036854775808, ( + f"Expected int64_min to be -9223372036854775808, got {get_value('int64_min')}" + ) + assert get_value("int64_max") == 9223372036854775807, ( + f"Expected int64_max to be 9223372036854775807, got {get_value('int64_max')}" + ) # Integer types (unsigned) assert get_value("uint8_min") == 0, f"Expected uint8_min to be 0, got {get_value('uint8_min')}" - assert ( - get_value("uint8_max") == 255 - ), f"Expected uint8_max to be 255, got {get_value('uint8_max')}" - assert ( - get_value("uint16_min") == 0 - ), f"Expected uint16_min to be 0, got {get_value('uint16_min')}" - assert ( - get_value("uint16_max") == 65535 - ), f"Expected uint16_max to be 65535, got {get_value('uint16_max')}" - assert ( - get_value("uint32_min") == 0 - ), f"Expected uint32_min to be 0, got {get_value('uint32_min')}" - assert ( - 
get_value("uint32_max") == 4294967295 - ), f"Expected uint32_max to be 4294967295, got {get_value('uint32_max')}" - assert ( - get_value("uint64_min") == 0 - ), f"Expected uint64_min to be 0, got {get_value('uint64_min')}" - assert ( - get_value("uint64_max") == 18446744073709551615 - ), f"Expected uint64_max to be 18446744073709551615, got {get_value('uint64_max')}" + assert get_value("uint8_max") == 255, f"Expected uint8_max to be 255, got {get_value('uint8_max')}" + assert get_value("uint16_min") == 0, f"Expected uint16_min to be 0, got {get_value('uint16_min')}" + assert get_value("uint16_max") == 65535, f"Expected uint16_max to be 65535, got {get_value('uint16_max')}" + assert get_value("uint32_min") == 0, f"Expected uint32_min to be 0, got {get_value('uint32_min')}" + assert get_value("uint32_max") == 4294967295, f"Expected uint32_max to be 4294967295, got {get_value('uint32_max')}" + assert get_value("uint64_min") == 0, f"Expected uint64_min to be 0, got {get_value('uint64_min')}" + assert get_value("uint64_max") == 18446744073709551615, ( + f"Expected uint64_max to be 18446744073709551615, got {get_value('uint64_max')}" + ) # Floating point types - assert ( - abs(get_value("float_value") - 3.14) < 0.0001 - ), f"Expected float_value to be close to 3.14, got {get_value('float_value')}" - assert ( - abs(get_value("double_value") - 2.7182818284590452) < 0.0000000000000001 - ), f"Expected double_value to be close to 2.7182818284590452, got {get_value('double_value')}" + assert abs(get_value("float_value") - 3.14) < 0.0001, ( + f"Expected float_value to be close to 3.14, got {get_value('float_value')}" + ) + assert abs(get_value("double_value") - 2.7182818284590452) < 0.0000000000000001, ( + f"Expected double_value to be close to 2.7182818284590452, got {get_value('double_value')}" + ) # String types - expect only str, not bytes string_value = get_value("string_value") - assert ( - string_value == "Hello, World!" 
- ), f"Expected string_value to be 'Hello, World!', got {string_value}" + assert string_value == "Hello, World!", f"Expected string_value to be 'Hello, World!', got {string_value}" utf8_value = get_value("utf8_value") assert utf8_value == "UTF8 строка", f"Expected utf8_value to be 'UTF8 строка', got {utf8_value}" uuid_value = get_value("uuid_value") - assert ( - uuid_value == "00000000-0000-0000-0000-000000000000" - ), f"Expected uuid_value to be '00000000-0000-0000-0000-000000000000', got {uuid_value}" + assert uuid_value == "00000000-0000-0000-0000-000000000000", ( + f"Expected uuid_value to be '00000000-0000-0000-0000-000000000000', got {uuid_value}" + ) json_value = get_value("json_value") - assert ( - json_value == '{"key": "value"}' - ), f"Expected json_value to be '{{'key': 'value'}}', got {json_value}" + assert json_value == '{"key": "value"}', f"Expected json_value to be '{{'key': 'value'}}', got {json_value}" # Date and time types - YDB returns these as Python datetime objects date_value = get_value("date_value") if isinstance(date_value, str): # Parse string to date parsed_date = datetime.date.fromisoformat(date_value) - assert parsed_date == datetime.date( - 2023, 7, 15 - ), f"Expected date_value to be 2023-07-15, got {parsed_date}" + assert parsed_date == datetime.date(2023, 7, 15), f"Expected date_value to be 2023-07-15, got {parsed_date}" else: - assert isinstance( - date_value, datetime.date - ), f"Expected date_value to be datetime.date, got {type(date_value)}" - assert date_value == datetime.date( - 2023, 7, 15 - ), f"Expected date_value to be 2023-07-15, got {date_value}" + assert isinstance(date_value, datetime.date), f"Expected date_value to be datetime.date, got {type(date_value)}" + assert date_value == datetime.date(2023, 7, 15), f"Expected date_value to be 2023-07-15, got {date_value}" datetime_value = get_value("datetime_value") if isinstance(datetime_value, str): @@ -698,66 +601,56 @@ def get_value(column_name): expected_datetime = 
datetime.datetime(2023, 7, 15, 12, 30, 45, tzinfo=datetime.timezone.utc) if parsed_dt.tzinfo is None: parsed_dt = parsed_dt.replace(tzinfo=datetime.timezone.utc) - assert ( - parsed_dt == expected_datetime - ), f"Expected datetime_value to be {expected_datetime}, got {parsed_dt}" + assert parsed_dt == expected_datetime, f"Expected datetime_value to be {expected_datetime}, got {parsed_dt}" else: - assert isinstance( - datetime_value, datetime.datetime - ), f"Expected datetime_value to be datetime.datetime, got {type(datetime_value)}" + assert isinstance(datetime_value, datetime.datetime), ( + f"Expected datetime_value to be datetime.datetime, got {type(datetime_value)}" + ) expected_datetime = datetime.datetime(2023, 7, 15, 12, 30, 45, tzinfo=datetime.timezone.utc) if datetime_value.tzinfo is None: datetime_value = datetime_value.replace(tzinfo=datetime.timezone.utc) - assert ( - datetime_value == expected_datetime - ), f"Expected datetime_value to be {expected_datetime}, got {datetime_value}" + assert datetime_value == expected_datetime, ( + f"Expected datetime_value to be {expected_datetime}, got {datetime_value}" + ) timestamp_value = get_value("timestamp_value") if isinstance(timestamp_value, str): parsed_ts = datetime.datetime.fromisoformat(timestamp_value.replace("Z", "+00:00")) - expected_timestamp = datetime.datetime( - 2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc - ) + expected_timestamp = datetime.datetime(2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc) if parsed_ts.tzinfo is None: parsed_ts = parsed_ts.replace(tzinfo=datetime.timezone.utc) - assert ( - parsed_ts == expected_timestamp - ), f"Expected timestamp_value to be {expected_timestamp}, got {parsed_ts}" + assert parsed_ts == expected_timestamp, f"Expected timestamp_value to be {expected_timestamp}, got {parsed_ts}" else: - assert isinstance( - timestamp_value, datetime.datetime - ), f"Expected timestamp_value to be datetime.datetime, got {type(timestamp_value)}" - 
expected_timestamp = datetime.datetime( - 2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc + assert isinstance(timestamp_value, datetime.datetime), ( + f"Expected timestamp_value to be datetime.datetime, got {type(timestamp_value)}" ) + expected_timestamp = datetime.datetime(2023, 7, 15, 12, 30, 45, 123456, tzinfo=datetime.timezone.utc) if timestamp_value.tzinfo is None: timestamp_value = timestamp_value.replace(tzinfo=datetime.timezone.utc) - assert ( - timestamp_value == expected_timestamp - ), f"Expected timestamp_value to be {expected_timestamp}, got {timestamp_value}" + assert timestamp_value == expected_timestamp, ( + f"Expected timestamp_value to be {expected_timestamp}, got {timestamp_value}" + ) interval_value = get_value("interval_value") # Accept both string and timedelta for interval_value - expected_interval = datetime.timedelta( - days=1, hours=2, minutes=3, seconds=4, microseconds=567000 - ) + expected_interval = datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=567000) if isinstance(interval_value, str): # Parse string like '93784.567s' to seconds if interval_value.endswith("s"): seconds = float(interval_value[:-1]) parsed_interval = datetime.timedelta(seconds=seconds) - assert ( - parsed_interval.total_seconds() == expected_interval.total_seconds() - ), f"Expected interval_value to be {expected_interval}, got {parsed_interval}" + assert parsed_interval.total_seconds() == expected_interval.total_seconds(), ( + f"Expected interval_value to be {expected_interval}, got {parsed_interval}" + ) else: assert False, f"Unexpected interval string format: {interval_value}" else: - assert isinstance( - interval_value, datetime.timedelta - ), f"Expected interval_value to be datetime.timedelta, got {type(interval_value)}" - assert ( - interval_value.total_seconds() == expected_interval.total_seconds() - ), f"Expected interval_value to be {expected_interval}, got {interval_value}" + assert isinstance(interval_value, 
datetime.timedelta), ( + f"Expected interval_value to be datetime.timedelta, got {type(interval_value)}" + ) + assert interval_value.total_seconds() == expected_interval.total_seconds(), ( + f"Expected interval_value to be {expected_interval}, got {interval_value}" + ) # Decimal - YDB returns Decimal objects from decimal import Decimal @@ -765,16 +658,14 @@ def get_value(column_name): decimal_value = get_value("decimal_value") if isinstance(decimal_value, str): parsed_decimal = Decimal(decimal_value) - assert parsed_decimal == Decimal( - "123.456789" - ), f"Expected decimal_value to be Decimal('123.456789'), got {parsed_decimal}" + assert parsed_decimal == Decimal("123.456789"), ( + f"Expected decimal_value to be Decimal('123.456789'), got {parsed_decimal}" + ) else: - assert isinstance( - decimal_value, Decimal - ), f"Expected decimal_value to be Decimal, got {type(decimal_value)}" - assert decimal_value == Decimal( - "123.456789" - ), f"Expected decimal_value to be Decimal('123.456789'), got {decimal_value}" + assert isinstance(decimal_value, Decimal), f"Expected decimal_value to be Decimal, got {type(decimal_value)}" + assert decimal_value == Decimal("123.456789"), ( + f"Expected decimal_value to be Decimal('123.456789'), got {decimal_value}" + ) # Container types # List containers @@ -783,52 +674,40 @@ def get_value(column_name): assert int_list == [1, 2, 3], f"Expected int_list to be [1, 2, 3], got {int_list}" string_list = get_value("string_list") - assert isinstance( - string_list, list - ), f"Expected string_list to be a list, got {type(string_list)}" + assert isinstance(string_list, list), f"Expected string_list to be a list, got {type(string_list)}" expected = ["a", "b", "c"] for actual, exp in zip(string_list, expected): assert actual == exp, f"Expected {exp}, got {actual} in string_list" # Struct containers (similar to Python dictionaries) simple_struct = get_value("simple_struct") - assert isinstance( - simple_struct, dict - ), f"Expected simple_struct 
to be a dict, got {type(simple_struct)}" - assert ( - "a" in simple_struct and "b" in simple_struct - ), f"Expected simple_struct to have keys 'a' and 'b', got {simple_struct}" + assert isinstance(simple_struct, dict), f"Expected simple_struct to be a dict, got {type(simple_struct)}" + assert "a" in simple_struct and "b" in simple_struct, ( + f"Expected simple_struct to have keys 'a' and 'b', got {simple_struct}" + ) assert simple_struct["a"] == 1, f"Expected simple_struct['a'] to be 1, got {simple_struct['a']}" - assert ( - simple_struct["b"] == "x" - ), f"Expected simple_struct['b'] to be 'x', got {simple_struct['b']}" + assert simple_struct["b"] == "x", f"Expected simple_struct['b'] to be 'x', got {simple_struct['b']}" # Dictionary containers string_to_int_dict = get_value("string_to_int_dict") - assert isinstance( - string_to_int_dict, dict - ), f"Expected string_to_int_dict to be a dict, got {type(string_to_int_dict)}" + assert isinstance(string_to_int_dict, dict), ( + f"Expected string_to_int_dict to be a dict, got {type(string_to_int_dict)}" + ) # Accept both string keys and stringified bytes keys expected_dict = {"key1": 1, "key2": 2, "key3": 3} expected_bytes_dict = {f"b'{k}'": v for k, v in expected_dict.items()} - assert ( - string_to_int_dict == expected_dict or string_to_int_dict == expected_bytes_dict - ), f"Expected dict to be {expected_dict} or {expected_bytes_dict}, got {string_to_int_dict}" + assert string_to_int_dict == expected_dict or string_to_int_dict == expected_bytes_dict, ( + f"Expected dict to be {expected_dict} or {expected_bytes_dict}, got {string_to_int_dict}" + ) # Nested containers - list of structs list_of_structs = get_value("list_of_structs") - assert isinstance( - list_of_structs, list - ), f"Expected list_of_structs to be a list, got {type(list_of_structs)}" - assert ( - len(list_of_structs) == 3 - ), f"Expected list_of_structs to have 3 items, got {len(list_of_structs)}" + assert isinstance(list_of_structs, list), f"Expected 
list_of_structs to be a list, got {type(list_of_structs)}" + assert len(list_of_structs) == 3, f"Expected list_of_structs to have 3 items, got {len(list_of_structs)}" # Check first item in list of structs first_struct = list_of_structs[0] - assert isinstance( - first_struct, dict - ), f"Expected first_struct to be a dict, got {type(first_struct)}" + assert isinstance(first_struct, dict), f"Expected first_struct to be a dict, got {type(first_struct)}" assert first_struct == { "id": 1, "name": "Alice", @@ -836,57 +715,48 @@ def get_value(column_name): # Struct with list struct_with_list = get_value("struct_with_list") - assert isinstance( - struct_with_list, dict - ), f"Expected struct_with_list to be a dict, got {type(struct_with_list)}" + assert isinstance(struct_with_list, dict), f"Expected struct_with_list to be a dict, got {type(struct_with_list)}" assert struct_with_list == { "collection_name": "users", "ids": [1, 2, 3], "active": True, - }, f"Expected struct_with_list to be {{'collection_name': 'users', 'ids': [1, 2, 3], 'active': True}}, got {struct_with_list}" + }, ( + f"Expected struct_with_list to be {{'collection_name': 'users', 'ids': [1, 2, 3], 'active': True}}, " + f"got {struct_with_list}" + ) # Complex dict complex_dict = get_value("complex_dict") - assert isinstance( - complex_dict, dict - ), f"Expected complex_dict to be a dict, got {type(complex_dict)}" + assert isinstance(complex_dict, dict), f"Expected complex_dict to be a dict, got {type(complex_dict)}" expected_complex_dict = { "person1": {"id": 1, "name": "Alice", "scores": [25, 30, 28]}, "person2": {"id": 2, "name": "Bob", "scores": [22, 27, 29]}, } expected_bytes_complex_dict = {f"b'{k}'": v for k, v in expected_complex_dict.items()} - assert ( - complex_dict == expected_complex_dict or complex_dict == expected_bytes_complex_dict - ), f"Expected complex_dict to be {expected_complex_dict} or {expected_bytes_complex_dict}, got {complex_dict}" + assert complex_dict == expected_complex_dict or 
complex_dict == expected_bytes_complex_dict, ( + f"Expected complex_dict to be {expected_complex_dict} or {expected_bytes_complex_dict}, got {complex_dict}" + ) # Triple-nested list nested_list = get_value("nested_list_struct_list") - assert isinstance( - nested_list, list - ), f"Expected nested_list to be a list, got {type(nested_list)}" + assert isinstance(nested_list, list), f"Expected nested_list to be a list, got {type(nested_list)}" assert len(nested_list) == 2, f"Expected nested_list to have 2 items, got {len(nested_list)}" expected_nested_list = [ {"id": 1, "name": "Team A", "members": ["Alice", "Bob"]}, {"id": 2, "name": "Team B", "members": ["Charlie", "David"]}, ] - assert ( - nested_list == expected_nested_list - ), f"Expected nested_list to be {expected_nested_list}, got {nested_list}" + assert nested_list == expected_nested_list, f"Expected nested_list to be {expected_nested_list}, got {nested_list}" # Tuple containers mixed_tuple = get_value("mixed_tuple") - assert isinstance( - mixed_tuple, (list, tuple) - ), f"Expected mixed_tuple to be a list or tuple, got {type(mixed_tuple)}" + assert isinstance(mixed_tuple, (list, tuple)), f"Expected mixed_tuple to be a list or tuple, got {type(mixed_tuple)}" assert len(mixed_tuple) == 3, f"Expected mixed_tuple to have 3 items, got {len(mixed_tuple)}" expected_tuple = (1, "a", True) # Convert to tuple if it's a list for comparison if isinstance(mixed_tuple, list): mixed_tuple = tuple(mixed_tuple) - assert ( - mixed_tuple == expected_tuple - ), f"Expected mixed_tuple to be {expected_tuple}, got {mixed_tuple}" + assert mixed_tuple == expected_tuple, f"Expected mixed_tuple to be {expected_tuple}, got {mixed_tuple}" async def test_list_directory_integration(mcp_server): @@ -895,9 +765,9 @@ async def test_list_directory_integration(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_list_directory", path="/") # Parse the JSON result - assert ( - isinstance(result, list) and len(result) > 0 and "text" in 
result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) # Verify the structure @@ -906,9 +776,7 @@ async def test_list_directory_integration(mcp_server): assert "items" in parsed, f"Missing 'items' field in dir_data: {parsed}" # Root directory should have at least some items - assert isinstance( - parsed["items"], list - ), f"Expected items to be a list, got {type(parsed['items'])}" + assert isinstance(parsed["items"], list), f"Expected items to be a list, got {type(parsed['items'])}" assert len(parsed["items"]) > 0, f"Expected non-empty directory, got {parsed['items']}" # Verify at least one item has expected properties @@ -930,9 +798,9 @@ async def test_list_directory_nonexistent_integration(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_list_directory", path=nonexistent_path) # Parse the result - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) # Should contain an error message @@ -959,11 +827,9 @@ async def test_describe_path_integration(mcp_server): ); """, ) - assert ( - isinstance(create_result, list) - and len(create_result) > 0 - and "text" in create_result[0] - ), f"Result should be a list of dicts with 'text': {create_result}" + assert isinstance(create_result, list) and len(create_result) > 0 and "text" in create_result[0], ( + f"Result should be a list of dicts with 'text': {create_result}" + ) parsed = json.loads(create_result[0]["text"]) assert "error" not in parsed, f"Error creating table: {parsed}" @@ -974,16 +840,16 @@ async def 
test_describe_path_integration(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=f"/{test_table_name}") # Parse the JSON result - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) # Only check for path if not error if "error" not in parsed: - assert ( - parsed["path"] == f"/{test_table_name}" - ), f"Expected path to be '/{test_table_name}', got {parsed['path']}" + assert parsed["path"] == f"/{test_table_name}", ( + f"Expected path to be '/{test_table_name}', got {parsed['path']}" + ) assert "type" in parsed, f"Missing 'type' field in path_data: {parsed}" assert parsed["type"] == "TABLE", f"Expected type to be 'TABLE', got {parsed['type']}" # Verify table information @@ -991,9 +857,7 @@ async def test_describe_path_integration(mcp_server): finally: # Clean up - drop the table even if test fails - cleanup_result = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};" - ) + cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};") logger.debug(f"Table cleanup result: {cleanup_result}") @@ -1006,9 +870,9 @@ async def test_describe_nonexistent_path_integration(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=nonexistent_path) # Parse the result - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) # Should contain an error message @@ -1020,9 +884,9 @@ async def 
test_ydb_status_integration(mcp_server): result = await call_mcp_tool(mcp_server, "ydb_status") # Parse the JSON result - assert ( - isinstance(result, list) and len(result) > 0 and "text" in result[0] - ), f"Result should be a list of dicts with 'text': {result}" + assert isinstance(result, list) and len(result) > 0 and "text" in result[0], ( + f"Result should be a list of dicts with 'text': {result}" + ) parsed = json.loads(result[0]["text"]) # Verify the structure @@ -1034,9 +898,9 @@ async def test_ydb_status_integration(mcp_server): # For a successful test run, we expect to be connected assert parsed["status"] == "running", f"Expected status to be 'running', got {parsed['status']}" - assert ( - parsed["ydb_connection"] == "connected" - ), f"Expected ydb_connection to be 'connected', got {parsed['ydb_connection']}" + assert parsed["ydb_connection"] == "connected", ( + f"Expected ydb_connection to be 'connected', got {parsed['ydb_connection']}" + ) assert parsed["error"] is None, f"Expected no error, got: {parsed.get('error')}" logger.info(f"YDB status check successful: {parsed}") diff --git a/tests/integration/test_path_operations.py b/tests/integration/test_path_operations.py index ecc3b1e..0fa2ee8 100644 --- a/tests/integration/test_path_operations.py +++ b/tests/integration/test_path_operations.py @@ -1,6 +1,7 @@ """Integration tests for YDB directory and table operations (list_directory, describe_path, table creation, and cleanup). -These tests validate the functionality of YDB directory listing, path description, and table operations including creation and cleanup. +These tests validate the functionality of YDB directory listing, path description, and table operations +including creation and cleanup. They test real YDB interactions without mocks, requiring a running YDB instance. 
""" @@ -104,9 +105,7 @@ async def test_list_directory_after_table_creation(mcp_server): finally: # Clean up - drop the table - cleanup_result = await call_mcp_tool( - mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};" - ) + cleanup_result = await call_mcp_tool(mcp_server, "ydb_query", sql=f"DROP TABLE {test_table_name};") logger.debug(f"Table cleanup result: {cleanup_result}") @@ -123,17 +122,11 @@ async def test_path_description(mcp_server): describe_result = await call_mcp_tool(mcp_server, "ydb_describe_path", path=item_path) path_data = parse_text_content(describe_result) assert "path" in path_data, f"Missing 'path' field in path data: {path_data}" - assert ( - path_data["path"] == item_path - ), f"Expected path to be '{item_path}', got {path_data['path']}" + assert path_data["path"] == item_path, f"Expected path to be '{item_path}', got {path_data['path']}" assert "type" in path_data, f"Missing 'type' field in path data: {path_data}" assert "name" in path_data, f"Missing 'name' field in path data: {path_data}" assert "owner" in path_data, f"Missing 'owner' field in path data: {path_data}" if path_data["type"] == "TABLE": assert "table" in path_data, f"Missing 'table' field for TABLE: {path_data}" - assert ( - "columns" in path_data["table"] - ), f"Missing 'columns' field in table data: {path_data}" - assert ( - len(path_data["table"]["columns"]) > 0 - ), f"Table should have at least one column: {path_data}" + assert "columns" in path_data["table"], f"Missing 'columns' field in table data: {path_data}" + assert len(path_data["table"]["columns"]) > 0, f"Table should have at least one column: {path_data}" diff --git a/tests/mocks.py b/tests/mocks.py index 72eba57..5dab3d8 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,6 +1,6 @@ """Mock classes for testing.""" -from typing import Any, Callable, Dict, Optional, Type +from typing import Callable, Type class MockRequestHandler: diff --git a/tests/test_connection.py b/tests/test_connection.py index 
ba7f4d9..8a6681d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -15,7 +15,7 @@ sys.modules["mcp.server.handler"].register_handler = mock_register_handler # Import after mocking -from ydb_mcp.connection import YDBConnection +from ydb_mcp.connection import YDBConnection # noqa: E402 class TestYDBConnection(unittest.TestCase): @@ -45,9 +45,7 @@ async def test_connect(self, mock_driver_class): # Setup mocks mock_driver = AsyncMock() mock_driver.wait = AsyncMock(return_value=True) - mock_driver.discovery_debug_details = MagicMock( - return_value="Resolved endpoints: localhost:2136" - ) + mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136") mock_driver_class.return_value = mock_driver with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class: @@ -84,9 +82,7 @@ async def test_connect_with_database_in_endpoint(self, mock_driver_class): # Setup mocks mock_driver = AsyncMock() mock_driver.wait = AsyncMock(return_value=True) - mock_driver.discovery_debug_details = MagicMock( - return_value="Resolved endpoints: localhost:2136" - ) + mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136") mock_driver_class.return_value = mock_driver with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class: @@ -125,9 +121,7 @@ async def test_connect_with_explicit_database(self, mock_driver_class): # Setup mocks mock_driver = AsyncMock() mock_driver.wait = AsyncMock(return_value=True) - mock_driver.discovery_debug_details = MagicMock( - return_value="Resolved endpoints: localhost:2136" - ) + mock_driver.discovery_debug_details = MagicMock(return_value="Resolved endpoints: localhost:2136") mock_driver_class.return_value = mock_driver with patch("ydb.aio.QuerySessionPool") as mock_session_pool_class: diff --git a/tests/test_customjsonencoder.py b/tests/test_customjsonencoder.py index 94f5652..65b19f2 100644 --- a/tests/test_customjsonencoder.py +++ 
b/tests/test_customjsonencoder.py @@ -5,8 +5,6 @@ import decimal import json -import pytest - from ydb_mcp.server import CustomJSONEncoder @@ -38,9 +36,7 @@ def test_datetime_serialization(): assert deserialized["datetime"] == "2023-07-15T12:30:45" assert deserialized["date"] == "2023-07-15" assert deserialized["time"] == "12:30:45" - assert ( - deserialized["timedelta"] == "93784.567s" - ) # 1 day, 2 hours, 3 minutes, 4.567 seconds in seconds + assert deserialized["timedelta"] == "93784.567s" # 1 day, 2 hours, 3 minutes, 4.567 seconds in seconds assert deserialized["nested"]["datetime"] == "2023-07-15T12:30:45" assert deserialized["list_with_dates"][0] == "2023-07-15" assert deserialized["list_with_dates"][1] == "2023-07-15T12:30:45" diff --git a/tests/test_query.py b/tests/test_query.py index be26912..6f472e0 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -15,8 +15,8 @@ sys.modules["mcp.server.handler"].register_handler = mock_register_handler # Import modules after mocking -from ydb_mcp.connection import YDBConnection -from ydb_mcp.query import QueryExecutor +from ydb_mcp.connection import YDBConnection # noqa: E402 +from ydb_mcp.query import QueryExecutor # noqa: E402 class TestQueryExecutor(unittest.TestCase): @@ -103,9 +103,7 @@ def test_execute_query_sync(self): """Test _execute_query_sync method.""" # Setup session pool to return mock session mock_session = MagicMock() - self.mock_connection.session_pool.retry_operation_sync.side_effect = ( - lambda callback: callback(mock_session) - ) + self.mock_connection.session_pool.retry_operation_sync.side_effect = lambda callback: callback(mock_session) # Setup transaction mock mock_transaction = MagicMock() diff --git a/tests/test_server.py b/tests/test_server.py index 5f044dd..17ff733 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,19 +1,16 @@ """Unit tests for YDB MCP server implementation.""" import asyncio -import base64 import datetime import decimal import json # Patch the 
mcp module before importing the YDBMCPServer import sys -import unittest from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch import pytest import ydb -from mcp import Tool from mcp.types import TextContent sys.modules["mcp.server"] = MagicMock() @@ -21,7 +18,7 @@ sys.modules["mcp.server.handler"].RequestHandler = MagicMock sys.modules["mcp.server.handler"].register_handler = lambda name: lambda cls: cls -from ydb_mcp.server import CustomJSONEncoder, YDBMCPServer +from ydb_mcp.server import CustomJSONEncoder, YDBMCPServer # noqa: E402 @pytest.mark.unit @@ -31,9 +28,7 @@ class TestYDBMCPServer: # Initialization tests async def test_init_with_env_vars(self): """Test initialization with environment variables.""" - with patch( - "os.environ", {"YDB_ENDPOINT": "test-endpoint", "YDB_DATABASE": "test-database"} - ): + with patch("os.environ", {"YDB_ENDPOINT": "test-endpoint", "YDB_DATABASE": "test-database"}): with patch.object(YDBMCPServer, "register_tools"): server = YDBMCPServer() assert server.endpoint == "test-endpoint" @@ -98,23 +93,17 @@ async def test_query_with_params(self): server = YDBMCPServer(endpoint="test-endpoint", database="test-database") # Mock server.query method - server.query = AsyncMock( - return_value={"result_sets": [{"columns": ["test"], "rows": [["data"]]}]} - ) + server.query = AsyncMock(return_value={"result_sets": [{"columns": ["test"], "rows": [["data"]]}]}) # Create params as a JSON string params_json = json.dumps({"$param1": 123, "param2": "value"}) # Execute query with params - result = await server.query_with_params( - "SELECT * FROM table WHERE id = $param1", params_json - ) + result = await server.query_with_params("SELECT * FROM table WHERE id = $param1", params_json) # Check the query was executed with correct parameters expected_params = {"$param1": 123, "$param2": "value"} - server.query.assert_called_once_with( - "SELECT * FROM table WHERE id = $param1", expected_params - ) + 
server.query.assert_called_once_with("SELECT * FROM table WHERE id = $param1", expected_params) # Check the result assert result == {"result_sets": [{"columns": ["test"], "rows": [["data"]]}]} @@ -216,9 +205,9 @@ def _update_driver_config(self, driver_config): # Verify the error message indicates an authentication problem error_message = str(excinfo.value).lower() - assert ( - "authentication" in error_message or "invalid" in error_message - ), f"Expected authentication error, got: {error_message}" + assert "authentication" in error_message or "invalid" in error_message, ( + f"Expected authentication error, got: {error_message}" + ) # Directory and path tests async def test_list_directory(self): diff --git a/ydb_mcp/__init__.py b/ydb_mcp/__init__.py index 0959556..4cd203d 100644 --- a/ydb_mcp/__init__.py +++ b/ydb_mcp/__init__.py @@ -1,6 +1,6 @@ """YDB MCP - Model Context Protocol server for YDB.""" -from .version import VERSION +from .version import VERSION __version__ = VERSION diff --git a/ydb_mcp/__main__.py b/ydb_mcp/__main__.py index 9e72587..408db0c 100644 --- a/ydb_mcp/__main__.py +++ b/ydb_mcp/__main__.py @@ -103,7 +103,7 @@ def main(): auth_mode=auth_mode, ) - print(f"Starting YDB MCP server with stdio transport") + print("Starting YDB MCP server with stdio transport") print(f"YDB endpoint: {args.ydb_endpoint or 'Not set'}") print(f"YDB database: {args.ydb_database or 'Not set'}") print(f"YDB login: {'Set' if args.ydb_login else 'Not set'}") diff --git a/ydb_mcp/connection.py b/ydb_mcp/connection.py index 8c75bb9..518419b 100644 --- a/ydb_mcp/connection.py +++ b/ydb_mcp/connection.py @@ -5,7 +5,6 @@ from urllib.parse import urlparse import ydb -from ydb.aio import QuerySessionPool logger = logging.getLogger(__name__) @@ -13,7 +12,7 @@ class YDBConnection: """Manages YDB connection with async support.""" - def __init__(self, connection_string: str, database: str = None): + def __init__(self, connection_string: str, database: str | None = None): 
"""Initialize YDB connection. Args: @@ -24,7 +23,7 @@ def __init__(self, connection_string: str, database: str = None): self.driver: Optional[ydb.Driver] = None self.session_pool: Optional[ydb.aio.QuerySessionPool] = None self._database = database - self.last_error = None + self.last_error: str | None = None def _parse_endpoint_and_database(self) -> Tuple[str, str]: """Parse endpoint and database from connection string. @@ -94,7 +93,7 @@ async def connect(self) -> Tuple[ydb.Driver, ydb.aio.QuerySessionPool]: await asyncio.wait_for(self.driver.wait(), timeout=10.0) except asyncio.TimeoutError: self.last_error = "Connection timeout" - raise RuntimeError(f"YDB driver connection timeout after 10 seconds") + raise RuntimeError("YDB driver connection timeout after 10 seconds") # Check if we connected successfully if not self.driver.discovery_debug_details().startswith("Resolved endpoints"): diff --git a/ydb_mcp/query.py b/ydb_mcp/query.py index 3caa568..f3fce17 100644 --- a/ydb_mcp/query.py +++ b/ydb_mcp/query.py @@ -78,9 +78,12 @@ def _execute_query(session): result.append(self._convert_row_to_dict(row)) return result + if self._session_pool is None: + raise RuntimeError("SessionPool is not provided.") + return self._session_pool.retry_operation_sync(_execute_query) - def _convert_row_to_dict(self, row: Any, col_names: List[str] = None) -> Dict[str, Any]: + def _convert_row_to_dict(self, row: Any, col_names: List[str] | None = None) -> Dict[str, Any]: """Convert a YDB result row to a dictionary. 
Args: @@ -125,9 +128,7 @@ def _convert_ydb_value(self, value: Any) -> Any: if isinstance(value, list): return [self._convert_ydb_value(item) for item in value] if isinstance(value, dict): - return { - self._convert_ydb_value(k): self._convert_ydb_value(v) for k, v in value.items() - } + return {self._convert_ydb_value(k): self._convert_ydb_value(v) for k, v in value.items()} if isinstance(value, tuple): return tuple(self._convert_ydb_value(item) for item in value) diff --git a/ydb_mcp/server.py b/ydb_mcp/server.py index b70845a..6d6eea5 100644 --- a/ydb_mcp/server.py +++ b/ydb_mcp/server.py @@ -4,20 +4,17 @@ import base64 import datetime import decimal -import gc import json import logging import os -import sys -import types -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Optional import ydb from mcp.server.fastmcp import FastMCP from mcp.types import TextContent -from ydb.aio import Driver as AsyncDriver from ydb.aio import QuerySessionPool +from ydb_mcp.connection import YDBConnection from ydb_mcp.tool_manager import ToolManager logger = logging.getLogger(__name__) @@ -94,15 +91,15 @@ class YDBMCPServer(FastMCP): def __init__( self, - endpoint: str = None, - database: str = None, - credentials_factory: Optional[Callable[[], ydb.Credentials]] = None, + endpoint: str | None = None, + database: str | None = None, + credentials_factory: Callable[[], ydb.Credentials] | None = None, ydb_connection_string: str = "", - tool_manager: Optional[ToolManager] = None, - auth_mode: str = None, - login: str = None, - password: str = None, - root_certificates: str = None, + tool_manager: ToolManager | None = None, + auth_mode: str | None = None, + login: str | None = None, + password: str | None = None, + root_certificates: str | None = None, *args, **kwargs, ): @@ -127,14 +124,14 @@ def __init__( self.database = database or os.environ.get("YDB_DATABASE", "/local") self.credentials_factory = 
credentials_factory self.ydb_connection_string = ydb_connection_string - self.auth_error = None + self.auth_error: str | None = None self._loop = None self.pool = None self.tool_manager = tool_manager or ToolManager() self._driver_lock = asyncio.Lock() self._pool_lock = asyncio.Lock() self.root_certificates = root_certificates - self._original_methods = {} + self._original_methods: Dict = {} # Authentication settings supported_auth_modes = {AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD} @@ -155,11 +152,7 @@ def __init__( def _restore_ydb_patches(self): """Restore original YDB methods that were patched.""" # Restore topic client __del__ method - if ( - "topic_client_del" in self._original_methods - and hasattr(ydb, "topic") - and hasattr(ydb.topic, "TopicClient") - ): + if "topic_client_del" in self._original_methods and hasattr(ydb, "topic") and hasattr(ydb.topic, "TopicClient"): if self._original_methods["topic_client_del"] is not None: ydb.topic.TopicClient.__del__ = self._original_methods["topic_client_del"] else: @@ -229,9 +222,7 @@ async def create_driver(self): # Initialize driver with latest API await self.driver.wait(timeout=5.0) # Check if we connected successfully - debug_details = await self._loop.run_in_executor( - None, lambda: self.driver.discovery_debug_details() - ) + debug_details = await self._loop.run_in_executor(None, lambda: self.driver.discovery_debug_details()) if not debug_details.startswith("Resolved endpoints"): self.auth_error = f"Failed to connect to YDB server: {debug_details}" logger.error(self.auth_error) @@ -274,9 +265,7 @@ async def _terminate_discovery(self, discovery): logger.warning(f"Error waiting for discovery task cancellation: {e}") # Handle any streaming response generators that might be running - if hasattr(discovery, "_fetch_stream_responses") and callable( - discovery._fetch_stream_responses - ): + if hasattr(discovery, "_fetch_stream_responses") and callable(discovery._fetch_stream_responses): # This is a generator 
method that might be active # Nothing to do directly - the generator will be GC'ed when the driver is destroyed pass @@ -313,9 +302,7 @@ async def _cancel_ydb_related_tasks(self): # Wait briefly for tasks to cancel if discovery_tasks: try: - await asyncio.wait_for( - asyncio.gather(*discovery_tasks, return_exceptions=True), timeout=0.5 - ) + await asyncio.wait_for(asyncio.gather(*discovery_tasks, return_exceptions=True), timeout=0.5) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -376,11 +363,7 @@ async def query(self, sql: str, params: Optional[Dict[str, Any]] = None) -> List all_results.append(processed) # Convert all dict keys to strings for JSON serialization safe_result = self._stringify_dict_keys({"result_sets": all_results}) - return [ - TextContent( - type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder) - ) - ] + return [TextContent(type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder))] except Exception as e: error_message = str(e) safe_error = self._stringify_dict_keys({"error": error_message}) @@ -437,9 +420,7 @@ async def query_with_params(self, sql: str, params: str) -> List[TextContent]: # Handle authentication errors if self.auth_error: logger.error(f"Authentication error: {self.auth_error}") - safe_error = self._stringify_dict_keys( - {"error": f"Authentication error: {self.auth_error}"} - ) + safe_error = self._stringify_dict_keys({"error": f"Authentication error: {self.auth_error}"}) return [TextContent(type="text", text=json.dumps(safe_error, indent=2))] parsed_params = {} try: @@ -447,9 +428,7 @@ async def query_with_params(self, sql: str, params: str) -> List[TextContent]: parsed_params = json.loads(params) except json.JSONDecodeError as e: logger.error(f"Error parsing JSON parameters: {str(e)}") - safe_error = self._stringify_dict_keys( - {"error": f"Error parsing JSON parameters: {str(e)}"} - ) + safe_error = self._stringify_dict_keys({"error": f"Error parsing JSON parameters: 
{str(e)}"}) return [TextContent(type="text", text=json.dumps(safe_error, indent=2))] # Convert [value, type] to YDB type if needed ydb_params = {} @@ -608,11 +587,7 @@ async def list_directory(self, path: str) -> List[TextContent]: await self.create_driver() if self.driver is None: - return [ - TextContent( - type="text", text=json.dumps({"error": "Failed to create driver"}, indent=2) - ) - ] + return [TextContent(type="text", text=json.dumps({"error": "Failed to create driver"}, indent=2))] # Access the scheme client scheme_client = self.driver.scheme_client @@ -650,17 +625,11 @@ async def list_directory(self, path: str) -> List[TextContent]: # Convert all dict keys to strings for JSON serialization safe_result = self._stringify_dict_keys(result) - return [ - TextContent( - type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder) - ) - ] + return [TextContent(type="text", text=json.dumps(safe_result, indent=2, cls=CustomJSONEncoder))] except Exception as e: logger.exception(f"Error listing directory {path}: {e}") - safe_error = self._stringify_dict_keys( - {"error": f"Error listing directory {path}: {str(e)}"} - ) + safe_error = self._stringify_dict_keys({"error": f"Error listing directory {path}: {str(e)}"}) return [TextContent(type="text", text=json.dumps(safe_error, indent=2))] async def describe_path(self, path: str) -> List[TextContent]: @@ -748,14 +717,8 @@ async def describe_path(self, path: str) -> List[TextContent]: index_info = { "name": index.name, "index_columns": list(index.index_columns), - "cover_columns": ( - list(index.cover_columns) - if hasattr(index, "cover_columns") - else [] - ), - "index_type": ( - str(index.index_type) if hasattr(index, "index_type") else None - ), + "cover_columns": (list(index.cover_columns) if hasattr(index, "cover_columns") else []), + "index_type": (str(index.index_type) if hasattr(index, "index_type") else None), } result["table"]["indexes"].append(index_info) @@ -765,11 +728,7 @@ async def 
describe_path(self, path: str) -> List[TextContent]: family_info = { "name": family.name, "data": family.data, - "compression": ( - str(family.compression) - if hasattr(family, "compression") - else None - ), + "compression": (str(family.compression) if hasattr(family, "compression") else None), } result["table"]["column_families"].append(family_info) @@ -789,21 +748,17 @@ async def describe_path(self, path: str) -> List[TextContent]: ps = table_desc.partitioning_settings if ps: if hasattr(ps, "partition_at_keys"): - result["table"]["partitioning_settings"][ - "partition_at_keys" - ] = ps.partition_at_keys + result["table"]["partitioning_settings"]["partition_at_keys"] = ps.partition_at_keys if hasattr(ps, "partition_by_size"): - result["table"]["partitioning_settings"][ - "partition_by_size" - ] = ps.partition_by_size + result["table"]["partitioning_settings"]["partition_by_size"] = ps.partition_by_size if hasattr(ps, "min_partitions_count"): - result["table"]["partitioning_settings"][ - "min_partitions_count" - ] = ps.min_partitions_count + result["table"]["partitioning_settings"]["min_partitions_count"] = ( + ps.min_partitions_count + ) if hasattr(ps, "max_partitions_count"): - result["table"]["partitioning_settings"][ - "max_partitions_count" - ] = ps.max_partitions_count + result["table"]["partitioning_settings"]["max_partitions_count"] = ( + ps.max_partitions_count + ) finally: # Always release the session @@ -816,9 +771,7 @@ async def describe_path(self, path: str) -> List[TextContent]: result["table"] = { "columns": [], "primary_key": ( - path_response.table.primary_key - if hasattr(path_response.table, "primary_key") - else [] + path_response.table.primary_key if hasattr(path_response.table, "primary_key") else [] ), "indexes": [], "partitioning_settings": {}, @@ -827,9 +780,7 @@ async def describe_path(self, path: str) -> List[TextContent]: # Add basic columns if hasattr(path_response.table, "columns"): for column in path_response.table.columns: - 
result["table"]["columns"].append( - {"name": column.name, "type": str(column.type)} - ) + result["table"]["columns"].append({"name": column.name, "type": str(column.type)}) # Add basic indexes if hasattr(path_response.table, "indexes"): @@ -838,9 +789,7 @@ async def describe_path(self, path: str) -> List[TextContent]: { "name": index.name, "index_columns": ( - list(index.index_columns) - if hasattr(index, "index_columns") - else [] + list(index.index_columns) if hasattr(index, "index_columns") else [] ), } ) @@ -850,21 +799,17 @@ async def describe_path(self, path: str) -> List[TextContent]: ps = path_response.table.partitioning_settings if ps: if hasattr(ps, "partition_at_keys"): - result["table"]["partitioning_settings"][ - "partition_at_keys" - ] = ps.partition_at_keys + result["table"]["partitioning_settings"]["partition_at_keys"] = ps.partition_at_keys if hasattr(ps, "partition_by_size"): - result["table"]["partitioning_settings"][ - "partition_by_size" - ] = ps.partition_by_size + result["table"]["partitioning_settings"]["partition_by_size"] = ps.partition_by_size if hasattr(ps, "min_partitions_count"): - result["table"]["partitioning_settings"][ - "min_partitions_count" - ] = ps.min_partitions_count + result["table"]["partitioning_settings"]["min_partitions_count"] = ( + ps.min_partitions_count + ) if hasattr(ps, "max_partitions_count"): - result["table"]["partitioning_settings"][ - "max_partitions_count" - ] = ps.max_partitions_count + result["table"]["partitioning_settings"]["max_partitions_count"] = ( + ps.max_partitions_count + ) # Convert to JSON string and return as TextContent formatted_result = json.dumps(result, indent=2, cls=CustomJSONEncoder) @@ -996,15 +941,11 @@ async def call_tool(self, tool_name: str, params: Dict[str, Any]) -> List[TextCo # Convert TextContent objects to dictionaries if needed if isinstance(result, list) and any(isinstance(item, TextContent) for item in result): serializable_result = self._text_content_to_dict(result) - 
return serializable_result + return serializable_result # type: ignore # Handle any other result type if result is None: - return [ - TextContent( - type="text", text="Operation completed successfully but returned no data" - ) - ] + return [TextContent(type="text", text="Operation completed successfully but returned no data")] return result @@ -1023,10 +964,10 @@ def get_tool_schema(self) -> List[Dict[str, Any]]: def run(self): """Run the YDB MCP server using the FastMCP server implementation.""" - print(f"Starting YDB MCP server") + print("Starting YDB MCP server") print(f"YDB endpoint: {self.endpoint or 'Not set'}") print(f"YDB database: {self.database or 'Not set'}") - logger.info(f"Starting YDB MCP server") + logger.info("Starting YDB MCP server") # Use FastMCP's built-in run method with stdio transport super().run(transport="stdio") @@ -1042,15 +983,15 @@ def get_credentials_factory(self) -> Optional[Callable[[], ydb.Credentials]]: supported_auth_modes = {AUTH_MODE_ANONYMOUS, AUTH_MODE_LOGIN_PASSWORD} if self.auth_mode not in supported_auth_modes: - self.auth_error = f"Unsupported auth mode: {self.auth_mode}. Supported modes: {', '.join(supported_auth_modes)}" + self.auth_error = ( + f"Unsupported auth mode: {self.auth_mode}. Supported modes: {', '.join(supported_auth_modes)}" + ) return None # If auth_mode is login_password and we have both login and password, use them if self.auth_mode == AUTH_MODE_LOGIN_PASSWORD: if not self.login or not self.password: - self.auth_error = ( - "Login and password must be provided for login-password authentication mode." - ) + self.auth_error = "Login and password must be provided for login-password authentication mode." 
return None logger.info(f"Using login/password authentication with user '{self.login}'") return self._login_password_credentials diff --git a/ydb_mcp/tool_manager.py b/ydb_mcp/tool_manager.py index 9126796..5c6b63e 100644 --- a/ydb_mcp/tool_manager.py +++ b/ydb_mcp/tool_manager.py @@ -4,9 +4,7 @@ class ToolDefinition: """Defines a tool that can be called by the MCP.""" - def __init__( - self, name: str, handler: Callable, description: str = "", parameters: Optional[Dict] = None - ): + def __init__(self, name: str, handler: Callable, description: str = "", parameters: Optional[Dict] = None): """Initialize a tool definition. Args: @@ -39,9 +37,7 @@ def register_tool( description: Tool description parameters: JSON schema for tool parameters """ - self._tools[name] = ToolDefinition( - name=name, handler=handler, description=description, parameters=parameters - ) + self._tools[name] = ToolDefinition(name=name, handler=handler, description=description, parameters=parameters) def get(self, name: str) -> Optional[ToolDefinition]: """Get a tool by name. @@ -70,7 +66,7 @@ def get_schema(self) -> List[Dict[str, Any]]: """ result = [] for name, tool in self._tools.items(): - tool_schema = { + tool_schema: Dict[str, Any] = { "name": name, "description": tool.description, }