Skip to content

Commit 27777dc

Browse files
committed
Add tests
1 parent e6bc448 commit 27777dc

File tree

7 files changed

+125
-0
lines changed

7 files changed

+125
-0
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,7 @@ select = [
6666

6767
[tool.ruff.lint.per-file-ignores]
6868
"tests/*" = ["S", "INP001"]
69+
70+
[tool.pytest.ini_options]
71+
asyncio_mode = "auto"
72+
asyncio_default_fixture_loop_scope = "session"

tests/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (C) 2024 Andrew Wason
2+
# SPDX-License-Identifier: MIT
3+
4+
from unittest import mock
5+
6+
import pytest
7+
from mcp import ClientSession, ListToolsResult, Tool
8+
from mcp.types import CallToolResult, TextContent
9+
10+
from langchain_mcp import MCPToolkit
11+
12+
13+
@pytest.fixture(scope="class")
14+
def mcptoolkit(request):
15+
session_mock = mock.AsyncMock(spec=ClientSession)
16+
session_mock.list_tools.return_value = ListToolsResult(
17+
tools=[
18+
Tool(
19+
name="read_file",
20+
description=(
21+
"Read the complete contents of a file from the file system. Handles various text encodings "
22+
"and provides detailed error messages if the file cannot be read. "
23+
"Use this tool when you need to examine the contents of a single file. "
24+
"Only works within allowed directories."
25+
),
26+
inputSchema={
27+
"type": "object",
28+
"properties": {"path": {"type": "string"}},
29+
"required": ["path"],
30+
"additionalProperties": False,
31+
"$schema": "http://json-schema.org/draft-07/schema#",
32+
},
33+
)
34+
]
35+
)
36+
session_mock.call_tool.return_value = CallToolResult(
37+
content=[TextContent(type="text", text="MIT License\n\nCopyright (c) 2024 Andrew Wason\n")],
38+
isError=False,
39+
)
40+
toolkit = MCPToolkit(session=session_mock)
41+
request.cls.toolkit = toolkit
42+
yield toolkit
43+
# session_mock.call_tool.assert_called_with("read_file", arguments={"path": "LICENSE"})
44+
45+
46+
@pytest.fixture(scope="class")
47+
async def mcptool(request, mcptoolkit):
48+
tool = (await request.cls.toolkit.get_tools())[0]
49+
request.cls.tool = tool
50+
yield tool

tests/demo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (C) 2024 Andrew Wason
2+
# SPDX-License-Identifier: MIT
3+
4+
import asyncio
5+
import pathlib
6+
7+
from langchain_core.messages import HumanMessage
8+
from langchain_core.output_parsers import StrOutputParser
9+
from langchain_groq import ChatGroq
10+
from mcp import ClientSession, StdioServerParameters
11+
from mcp.client.stdio import stdio_client
12+
13+
from langchain_mcp import MCPToolkit
14+
15+
16+
async def main():
17+
model = ChatGroq(model="llama-3.1-8b-instant") # requires GROQ_API_KEY
18+
server_params = StdioServerParameters(
19+
command="npx",
20+
args=["-y", "@modelcontextprotocol/server-filesystem", str(pathlib.Path(__file__).parent.parent)],
21+
)
22+
async with stdio_client(server_params) as (read, write):
23+
async with ClientSession(read, write) as session:
24+
toolkit = MCPToolkit(session=session)
25+
tools = await toolkit.get_tools()
26+
tools_map = {tool.name: tool for tool in tools}
27+
tools_model = model.bind_tools(tools)
28+
messages = [HumanMessage("Read and summarize the file ./LICENSE")]
29+
messages.append(await tools_model.ainvoke(messages))
30+
for tool_call in messages[-1].tool_calls:
31+
selected_tool = tools_map[tool_call["name"].lower()]
32+
tool_msg = await selected_tool.ainvoke(tool_call)
33+
messages.append(tool_msg)
34+
result = await (tools_model | StrOutputParser()).ainvoke(messages)
35+
print(result)
36+
37+
38+
if __name__ == "__main__":
39+
asyncio.run(main())

tests/integration_tests/__init__.py

Whitespace-only changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (C) 2024 Andrew Wason
2+
# SPDX-License-Identifier: MIT
3+
4+
import pytest
5+
from langchain_tests.unit_tests import ToolsUnitTests
6+
7+
8+
@pytest.mark.usefixtures("mcptool")
9+
class TestMCPToolIntegration(ToolsUnitTests):
10+
@property
11+
def tool_constructor(self):
12+
return self.tool
13+
14+
@property
15+
def tool_invoke_params_example(self) -> dict:
16+
return {"path": "LICENSE"}

tests/unit_tests/__init__.py

Whitespace-only changes.

tests/unit_tests/test_tool.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (C) 2024 Andrew Wason
2+
# SPDX-License-Identifier: MIT
3+
4+
import pytest
5+
from langchain_tests.unit_tests import ToolsUnitTests
6+
7+
8+
@pytest.mark.usefixtures("mcptool")
9+
class TestMCPToolUnit(ToolsUnitTests):
10+
@property
11+
def tool_constructor(self):
12+
return self.tool
13+
14+
@property
15+
def tool_invoke_params_example(self) -> dict:
16+
return {"path": "LICENSE"}

0 commit comments

Comments
 (0)