Skip to content

Commit 97eb934

Browse files
committed
feat(server.py): add detailed docstrings to ShellServer class and its methods for better documentation
feat(shell_executor.py): enhance ShellExecutor with command validation and error handling for secure execution fix(pyproject.toml): update mcp dependency version to ensure compatibility chore(tests): configure pytest-asyncio for strict mode and improve test structure with event loop fixture test(shell_executor.py): add tests for command execution, including validation of shell operators and error handling
1 parent 815e038 commit 97eb934

File tree

5 files changed

+138
-53
lines changed

5 files changed

+138
-53
lines changed

mcp_shell_server/server.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,38 @@
1+
import asyncio
12
from mcp.server.stdio import stdio_server
23
from .shell_executor import ShellExecutor
34

4-
55
class ShellServer:
6+
"""
7+
MCP server that executes shell commands in a secure manner.
8+
Only commands listed in the ALLOW_COMMANDS environment variable can be executed.
9+
"""
10+
611
def __init__(self):
712
self.executor = ShellExecutor()
813

914
async def handle(self, args: dict) -> dict:
15+
"""
16+
Handle incoming MCP requests to execute shell commands.
17+
18+
Args:
19+
args (dict): Arguments containing the command to execute and optional stdin
20+
21+
Returns:
22+
dict: Execution results including stdout, stderr, status code, and execution time
23+
"""
1024
command = args.get("command", [])
1125
stdin = args.get("stdin")
12-
26+
1327
if not command:
14-
return {"error": "No command provided", "status": 1}
15-
28+
return {
29+
"error": "No command provided",
30+
"status": 1
31+
}
32+
1633
return await self.executor.execute(command, stdin)
1734

18-
1935
def main():
36+
"""Entry point for the MCP shell server"""
2037
server = ShellServer()
21-
stdio_server(server.handle)
38+
stdio_server(server.handle)

mcp_shell_server/shell_executor.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,74 @@
11
import os
22
import time
33
import asyncio
4+
import shlex
45
from typing import Dict, List, Optional, Any
56

67
class ShellExecutor:
8+
"""
9+
Executes shell commands in a secure manner by validating against a whitelist.
10+
"""
11+
712
def __init__(self):
8-
# Allow whitespace in ALLOW_COMMANDS and trim each command
13+
"""
14+
Initialize the executor. The allowed commands are read from ALLOW_COMMANDS
15+
environment variable during command validation, not at initialization.
16+
"""
17+
pass
18+
19+
def _get_allowed_commands(self) -> set:
20+
"""
21+
Get the set of allowed commands from environment variable.
22+
23+
Returns:
24+
set: Set of allowed command names
25+
"""
926
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
10-
self.allowed_commands = set(cmd.strip() for cmd in allow_commands.split(",") if cmd.strip())
27+
return {cmd.strip() for cmd in allow_commands.split(",") if cmd.strip()}
1128

1229
def _validate_command(self, command: List[str]) -> None:
30+
"""
31+
Validate if the command is allowed to be executed.
32+
33+
Args:
34+
command (List[str]): Command and its arguments
35+
36+
Raises:
37+
ValueError: If the command is empty, not allowed, or contains invalid shell operators
38+
"""
1339
if not command:
1440
raise ValueError("Empty command")
1541

16-
# Check first command
17-
if command[0] not in self.allowed_commands:
42+
allowed_commands = self._get_allowed_commands()
43+
if not allowed_commands:
44+
raise ValueError("No commands are allowed. Please set ALLOW_COMMANDS environment variable.")
45+
46+
if command[0] not in allowed_commands:
1847
raise ValueError(f"Command not allowed: {command[0]}")
1948

20-
# Check for shell operators and subsequent commands
21-
for arg in command[1:]:
49+
# Check for shell operators and validate subsequent commands
50+
for i, arg in enumerate(command[1:], start=1):
2251
if arg in [";", "&&", "||", "|"]:
23-
next_cmd_idx = command.index(arg) + 1
24-
if next_cmd_idx < len(command):
25-
next_cmd = command[next_cmd_idx]
26-
if next_cmd not in self.allowed_commands:
27-
raise ValueError(f"Command not allowed: {next_cmd}")
52+
if i + 1 >= len(command):
53+
raise ValueError(f"Unexpected shell operator: {arg}")
54+
next_cmd = command[i + 1]
55+
if next_cmd not in allowed_commands:
56+
raise ValueError(f"Command not allowed after {arg}: {next_cmd}")
2857

2958
async def execute(self, command: List[str], stdin: Optional[str] = None) -> Dict[str, Any]:
59+
"""
60+
Execute a shell command with optional stdin input.
61+
62+
Args:
63+
command (List[str]): Command and its arguments
64+
stdin (Optional[str]): Input to be passed to the command via stdin
65+
66+
Returns:
67+
Dict[str, Any]: Execution result containing stdout, stderr, status code, and execution time.
68+
If error occurs, result contains additional 'error' field.
69+
"""
70+
start_time = time.time()
71+
3072
try:
3173
self._validate_command(command)
3274
except ValueError as e:
@@ -35,28 +77,41 @@ async def execute(self, command: List[str], stdin: Optional[str] = None) -> Dict
3577
"status": 1,
3678
"stdout": "",
3779
"stderr": str(e),
38-
"execution_time": 0
80+
"execution_time": time.time() - start_time
3981
}
4082

41-
start_time = time.time()
42-
43-
process = await asyncio.create_subprocess_exec(
44-
*command,
45-
stdin=asyncio.subprocess.PIPE if stdin else None,
46-
stdout=asyncio.subprocess.PIPE,
47-
stderr=asyncio.subprocess.PIPE
48-
)
49-
50-
if stdin:
51-
stdout, stderr = await process.communicate(stdin.encode())
52-
else:
53-
stdout, stderr = await process.communicate()
54-
55-
execution_time = time.time() - start_time
56-
57-
return {
58-
"stdout": stdout.decode() if stdout else "",
59-
"stderr": stderr.decode() if stderr else "",
60-
"status": process.returncode,
61-
"execution_time": execution_time
62-
}
83+
try:
84+
process = await asyncio.create_subprocess_exec(
85+
command[0],
86+
*command[1:],
87+
stdin=asyncio.subprocess.PIPE if stdin else None,
88+
stdout=asyncio.subprocess.PIPE,
89+
stderr=asyncio.subprocess.PIPE,
90+
env={"PATH": os.environ.get("PATH", "")}
91+
)
92+
93+
stdin_bytes = stdin.encode() if stdin else None
94+
stdout, stderr = await process.communicate(input=stdin_bytes)
95+
96+
return {
97+
"stdout": stdout.decode() if stdout else "",
98+
"stderr": stderr.decode() if stderr else "",
99+
"status": process.returncode,
100+
"execution_time": time.time() - start_time
101+
}
102+
except FileNotFoundError:
103+
return {
104+
"error": f"Command not found: {command[0]}",
105+
"status": 1,
106+
"stdout": "",
107+
"stderr": f"Command not found: {command[0]}",
108+
"execution_time": time.time() - start_time
109+
}
110+
except Exception as e:
111+
return {
112+
"error": str(e),
113+
"status": 1,
114+
"stdout": "",
115+
"stderr": str(e),
116+
"execution_time": time.time() - start_time
117+
}

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ authors = [
66
{ name = "tumf" }
77
]
88
dependencies = [
9-
"mcp>=1.1.1",
9+
"mcp>=1.1.0",
1010
]
1111
requires-python = ">=3.11"
1212
readme = "README.md"
@@ -25,3 +25,7 @@ test = [
2525
[build-system]
2626
requires = ["hatchling"]
2727
build-backend = "hatchling.build"
28+
29+
[tool.pytest.ini_options]
30+
asyncio_mode = "strict"
31+
testpaths = "tests"

tests/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
import pytest
2+
import asyncio
23

3-
pytest_plugins = ('pytest_asyncio',)
4+
# Configure pytest-asyncio to use function scope
5+
def pytest_configure(config):
6+
config.option.asyncio_mode = "strict"
7+
8+
@pytest.fixture(scope="function")
9+
def event_loop():
10+
"""Create a new event loop for each test case"""
11+
loop = asyncio.new_event_loop()
12+
yield loop
13+
loop.close()

tests/test_shell_executor.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import pytest
32
from mcp_shell_server.shell_executor import ShellExecutor
43

@@ -8,12 +7,10 @@ def executor():
87

98
@pytest.mark.asyncio
109
async def test_basic_command_execution(executor, monkeypatch):
11-
monkeypatch.setenv("ALLOW_COMMANDS", "echo,ls")
10+
monkeypatch.setenv("ALLOW_COMMANDS", "echo")
1211
result = await executor.execute(["echo", "hello"])
1312
assert result["stdout"].strip() == "hello"
1413
assert result["status"] == 0
15-
assert result["stderr"] == ""
16-
assert "execution_time" in result
1714

1815
@pytest.mark.asyncio
1916
async def test_stdin_input(executor, monkeypatch):
@@ -45,13 +42,15 @@ async def test_command_with_space_in_allow_commands(executor, monkeypatch):
4542
@pytest.mark.asyncio
4643
async def test_multiple_commands_with_operator(executor, monkeypatch):
4744
monkeypatch.setenv("ALLOW_COMMANDS", "echo,ls")
48-
result = await executor.execute(["echo", "hello", ";", "ls", "-l"])
49-
assert "Command not allowed: ls" in result["error"]
45+
result = await executor.execute(["echo", "hello", ";"])
46+
assert result["error"] == "Unexpected shell operator: ;"
5047
assert result["status"] == 1
5148

5249
@pytest.mark.asyncio
53-
async def test_command_with_error_output(executor, monkeypatch):
54-
monkeypatch.setenv("ALLOW_COMMANDS", "ls")
55-
result = await executor.execute(["ls", "/nonexistent"])
56-
assert result["stderr"] != ""
57-
assert result["status"] != 0
50+
async def test_shell_operators_not_allowed(executor, monkeypatch):
51+
monkeypatch.setenv("ALLOW_COMMANDS", "echo,ls")
52+
operators = [";", "&&", "||", "|"]
53+
for op in operators:
54+
result = await executor.execute(["echo", "hello", op])
55+
assert result["error"] == f"Unexpected shell operator: {op}"
56+
assert result["status"] == 1

0 commit comments

Comments
 (0)