Skip to content

Commit a320f35

Browse files
committed
mypy
1 parent 94610bd commit a320f35

File tree

6 files changed

+88
-11
lines changed

6 files changed

+88
-11
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,5 @@ jobs:
2828
- name: Lint
2929
run: |
3030
uv run --python ${{ matrix.python-version }} --python-preference only-system ruff format --check
31-
uv run --python ${{ matrix.python-version }} --python-preference only-system ruff check
31+
uv run --python ${{ matrix.python-version }} --python-preference only-system ruff check
32+
uv run --python ${{ matrix.python-version }} --python-preference only-system mypy

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ dev = [
2727
"pytest-asyncio~=0.24.0",
2828
"pytest-socket~=0.7.0",
2929
"ruff~=0.8.0",
30+
"mypy~=1.13.0",
31+
"typing-extensions~=4.12.2",
3032
]
3133

3234
[project.urls]
@@ -70,3 +72,13 @@ select = [
7072
[tool.pytest.ini_options]
7173
asyncio_mode = "auto"
7274
asyncio_default_fixture_loop_scope = "class"
75+
76+
[tool.mypy]
77+
disallow_untyped_defs = true
78+
warn_unused_configs = true
79+
warn_redundant_casts = true
80+
warn_unused_ignores = true
81+
strict_equality = true
82+
no_implicit_optional = true
83+
show_error_codes = true
84+
files = "src/**/*.py"

src/langchain_mcp/toolkit.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
# SPDX-License-Identifier: MIT
33

44
import asyncio
5-
import typing as t
65
import warnings
76
from collections.abc import Callable
87

98
import pydantic
109
import pydantic_core
10+
import typing_extensions as t
1111
from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
1212
from mcp import ClientSession
1313

@@ -24,7 +24,8 @@ class MCPToolkit(BaseToolkit):
2424

2525
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
2626

27-
async def get_tools(self) -> list[BaseTool]:
27+
@t.override
28+
async def get_tools(self) -> list[BaseTool]: # type: ignore[override]
2829
if not self._initialized:
2930
await self.session.initialize()
3031
self._initialized = True
@@ -33,7 +34,7 @@ async def get_tools(self) -> list[BaseTool]:
3334
MCPTool(
3435
session=self.session,
3536
name=tool.name,
36-
description=tool.description,
37+
description=tool.description or "",
3738
args_schema=create_schema_model(tool.inputSchema),
3839
)
3940
# list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
@@ -47,6 +48,7 @@ def create_schema_model(schema: dict[str, t.Any]) -> type[pydantic.BaseModel]:
4748
class Schema(pydantic.BaseModel):
4849
model_config = pydantic.ConfigDict(extra="allow", arbitrary_types_allowed=True)
4950

51+
@t.override
5052
@classmethod
5153
def model_json_schema(
5254
cls,
@@ -66,22 +68,25 @@ class MCPTool(BaseTool):
6668
"""
6769

6870
session: ClientSession
69-
71+
args_schema: type[pydantic.BaseModel]
7072
handle_tool_error: bool | str | Callable[[ToolException], str] | None = True
7173

74+
@t.override
7275
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
7376
warnings.warn(
7477
"Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy tests.", stacklevel=1
7578
)
7679
return asyncio.run(self._arun(*args, **kwargs))
7780

81+
@t.override
7882
async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
7983
result = await self.session.call_tool(self.name, arguments=kwargs)
8084
content = pydantic_core.to_json(result.content).decode()
8185
if result.isError:
8286
raise ToolException(content)
8387
return content
8488

89+
@t.override
8590
@property
8691
def tool_call_schema(self) -> type[pydantic.BaseModel]:
8792
return self.args_schema

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest import mock
55

66
import pytest
7+
from langchain_tests.integration_tests import ToolsIntegrationTests
78
from mcp import ClientSession, ListToolsResult, Tool
89
from mcp.types import CallToolResult, TextContent
910

@@ -38,13 +39,13 @@ def mcptoolkit(request):
3839
isError=False,
3940
)
4041
toolkit = MCPToolkit(session=session_mock)
41-
request.cls.toolkit = toolkit
4242
yield toolkit
43-
# session_mock.call_tool.assert_called_with("read_file", arguments={"path": "LICENSE"})
43+
if issubclass(request.cls, ToolsIntegrationTests):
44+
session_mock.call_tool.assert_called_with("read_file", arguments={"path": "LICENSE"})
4445

4546

4647
@pytest.fixture(scope="class")
4748
async def mcptool(request, mcptoolkit):
48-
tool = (await request.cls.toolkit.get_tools())[0]
49+
tool = (await mcptoolkit.get_tools())[0]
4950
request.cls.tool = tool
5051
yield tool

tests/demo.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
# Copyright (C) 2024 Andrew Wason
22
# SPDX-License-Identifier: MIT
33

4+
# /// script
5+
# requires-python = ">=3.10"
6+
# dependencies = [
7+
# "langchain-mcp",
8+
# "langchain-groq",
9+
# ]
10+
# ///
11+
12+
413
import asyncio
514
import pathlib
15+
import sys
616

717
from langchain_core.messages import HumanMessage
818
from langchain_core.output_parsers import StrOutputParser
@@ -13,7 +23,7 @@
1323
from langchain_mcp import MCPToolkit
1424

1525

16-
async def main():
26+
async def main(prompt: str) -> None:
1727
model = ChatGroq(model="llama-3.1-8b-instant") # requires GROQ_API_KEY
1828
server_params = StdioServerParameters(
1929
command="npx",
@@ -25,7 +35,7 @@ async def main():
2535
tools = await toolkit.get_tools()
2636
tools_map = {tool.name: tool for tool in tools}
2737
tools_model = model.bind_tools(tools)
28-
messages = [HumanMessage("Read and summarize the file ./LICENSE")]
38+
messages = [HumanMessage(prompt)]
2939
messages.append(await tools_model.ainvoke(messages))
3040
for tool_call in messages[-1].tool_calls:
3141
selected_tool = tools_map[tool_call["name"].lower()]
@@ -36,4 +46,5 @@ async def main():
3646

3747

3848
if __name__ == "__main__":
39-
asyncio.run(main())
49+
prompt = sys.argv[1] if len(sys.argv) > 1 else "Read and summarize the file ./LICENSE"
50+
asyncio.run(main(prompt))

uv.lock

Lines changed: 47 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)