Skip to content

Commit 33b0f31

Browse files
author
Yoshihiro Takahara
committed
feat: improve shell executor tests and fix edge cases
- Add new test cases for server validation - Add tests for command redirections - Improve shell executor error handling - Fix edge cases in pipe handling
1 parent ce27d7c commit 33b0f31

9 files changed

+312
-125
lines changed

src/mcp_shell_server/server.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,15 @@ async def run_tool(self, arguments: dict) -> Sequence[TextContent]:
8181
content: list[TextContent] = []
8282
try:
8383
# Handle execution with timeout
84-
result = await asyncio.wait_for(
85-
self.executor.execute(
86-
command, directory, stdin, None
87-
), # Pass None for timeout
88-
timeout=timeout,
89-
)
84+
try:
85+
result = await asyncio.wait_for(
86+
self.executor.execute(
87+
command, directory, stdin, None
88+
), # Pass None for timeout
89+
timeout=timeout,
90+
)
91+
except asyncio.TimeoutError as e:
92+
raise ValueError("Command execution timed out") from e
9093

9194
if result.get("error"):
9295
raise ValueError(result["error"])

src/mcp_shell_server/shell_executor.py

Lines changed: 72 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,26 @@ class ShellExecutor:
1212

1313
def __init__(self):
1414
"""
15-
Initialize the executor. The allowed commands are read from ALLOW_COMMANDS
16-
environment variable during command validation, not at initialization.
15+
Initialize the executor.
1716
"""
1817
pass
1918

19+
def _get_allowed_commands(self) -> set[str]:
20+
"""Get the set of allowed commands from environment variables"""
21+
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
22+
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
23+
commands = allow_commands + "," + allowed_commands
24+
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
25+
26+
def get_allowed_commands(self) -> list[str]:
27+
"""Get the list of allowed commands from environment variables"""
28+
return list(self._get_allowed_commands())
29+
30+
def is_command_allowed(self, command: str) -> bool:
31+
"""Check if a command is in the allowed list"""
32+
cmd = command.strip()
33+
return cmd in self._get_allowed_commands()
34+
2035
def _validate_redirection_syntax(self, command: List[str]) -> None:
2136
"""
2237
Validate the syntax of redirection operators in the command.
@@ -155,24 +170,12 @@ async def _cleanup_handles(
155170
"""
156171
for key in ["stdout", "stderr"]:
157172
handle = handles.get(key)
158-
if isinstance(handle, IO) and handle != asyncio.subprocess.PIPE:
173+
if handle and hasattr(handle, "close") and not isinstance(handle, int):
159174
try:
160175
handle.close()
161-
except IOError:
176+
except (IOError, ValueError):
162177
pass
163178

164-
def _get_allowed_commands(self) -> set:
165-
"""
166-
Get the set of allowed commands from environment variables.
167-
Checks both ALLOW_COMMANDS and ALLOWED_COMMANDS.
168-
"""
169-
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
170-
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
171-
172-
# Combine and deduplicate commands from both environment variables
173-
commands = allow_commands + "," + allowed_commands
174-
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
175-
176179
def _clean_command(self, command: List[str]) -> List[str]:
177180
"""
178181
Clean command by trimming whitespace from each part.
@@ -253,32 +256,36 @@ def _validate_directory(self, directory: Optional[str]) -> None:
253256
if not os.access(directory, os.R_OK | os.X_OK):
254257
raise ValueError(f"Directory is not accessible: {directory}")
255258

256-
def get_allowed_commands(self) -> list[str]:
257-
"""Get the allowed commands"""
258-
return list(self._get_allowed_commands())
259-
260259
def _validate_no_shell_operators(self, cmd: str) -> None:
261260
"""Validate that the command does not contain shell operators"""
262261
if cmd in [";" "&&", "||", "|"]:
263262
raise ValueError(f"Unexpected shell operator: {cmd}")
264263

265-
def _validate_pipeline(self, commands: List[str]) -> None:
266-
"""Validate pipeline command and ensure all parts are allowed"""
264+
def _validate_pipeline(self, commands: List[str]) -> Dict[str, str]:
265+
"""Validate pipeline command and ensure all parts are allowed
266+
267+
Returns:
268+
Dict[str, str]: Error message if validation fails, empty dict if success
269+
"""
267270
current_cmd: List[str] = []
268271

269272
for token in commands:
270273
if token == "|":
271274
if not current_cmd:
272275
raise ValueError("Empty command before pipe operator")
273-
self._validate_command(current_cmd)
276+
if not self.is_command_allowed(current_cmd[0]):
277+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
274278
current_cmd = []
275279
elif token in [";", "&&", "||"]:
276280
raise ValueError(f"Unexpected shell operator in pipeline: {token}")
277281
else:
278282
current_cmd.append(token)
279283

280284
if current_cmd:
281-
self._validate_command(current_cmd)
285+
if not self.is_command_allowed(current_cmd[0]):
286+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
287+
288+
return {}
282289

283290
def _split_pipe_commands(self, command: List[str]) -> List[List[str]]:
284291
"""
@@ -393,33 +400,55 @@ async def execute(
393400
"execution_time": time.time() - start_time,
394401
}
395402

396-
# Preprocess command to handle pipe operators
403+
# Process command
397404
preprocessed_command = self._preprocess_command(command)
398405
cleaned_command = self._clean_command(preprocessed_command)
399406
if not cleaned_command:
400-
raise ValueError("Empty command")
407+
return {
408+
"error": "Empty command",
409+
"status": 1,
410+
"stdout": "",
411+
"stderr": "Empty command",
412+
"execution_time": time.time() - start_time,
413+
}
401414

402415
# First check for pipe operators and handle pipeline
403416
if "|" in cleaned_command:
404-
commands: List[List[str]] = []
405-
current_cmd: List[str] = []
406-
for token in cleaned_command:
407-
if token == "|":
408-
if current_cmd:
409-
commands.append(current_cmd)
410-
current_cmd = []
417+
try:
418+
# Validate pipeline first
419+
error = self._validate_pipeline(cleaned_command)
420+
if error:
421+
return {
422+
**error,
423+
"status": 1,
424+
"stdout": "",
425+
"execution_time": time.time() - start_time,
426+
}
427+
428+
# Split commands
429+
commands: List[List[str]] = []
430+
current_cmd: List[str] = []
431+
for token in cleaned_command:
432+
if token == "|":
433+
if current_cmd:
434+
commands.append(current_cmd)
435+
current_cmd = []
436+
else:
437+
raise ValueError("Empty command before pipe operator")
411438
else:
412-
raise ValueError("Empty command before pipe operator")
413-
else:
414-
current_cmd.append(token)
415-
if current_cmd:
416-
commands.append(current_cmd)
439+
current_cmd.append(token)
440+
if current_cmd:
441+
commands.append(current_cmd)
417442

418-
# Validate each command in pipeline
419-
for cmd in commands:
420-
self._validate_command(cmd)
421-
422-
return await self._execute_pipeline(commands, directory, timeout)
443+
return await self._execute_pipeline(commands, directory, timeout)
444+
except ValueError as e:
445+
return {
446+
"error": str(e),
447+
"status": 1,
448+
"stdout": "",
449+
"stderr": str(e),
450+
"execution_time": time.time() - start_time,
451+
}
423452

424453
# Then check for other shell operators
425454
for token in cleaned_command:

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1+
import os
2+
3+
14
# Configure pytest-asyncio
25
def pytest_configure(config):
36
"""Configure pytest-asyncio defaults"""
47
config.option.asyncio_mode = "strict"
8+
# Enable command execution for tests
9+
os.environ["ALLOW_COMMANDS"] = "1"
10+
# Add allowed commands for tests
11+
os.environ["ALLOWED_COMMANDS"] = (
12+
"echo,sleep,cat,ls,pwd,touch,mkdir,rm,mv,cp,grep,awk,sed"
13+
)

tests/test_server.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@ def temp_test_dir():
1818
@pytest.mark.asyncio
1919
async def test_list_tools():
2020
"""Test listing of available tools"""
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_tool_execution_timeout():
25+
"""Test tool execution with timeout"""
26+
with pytest.raises(RuntimeError, match="Command execution timed out"):
27+
await call_tool(
28+
"shell_execute",
29+
{
30+
"command": ["sleep", "2"],
31+
"directory": "/tmp",
32+
"timeout": 1,
33+
},
34+
)
2135
tools = await list_tools()
2236
assert len(tools) == 1
2337
tool = tools[0]
@@ -183,7 +197,7 @@ async def test_call_tool_with_timeout(monkeypatch):
183197
monkeypatch.setenv("ALLOW_COMMANDS", "sleep")
184198
with pytest.raises(RuntimeError) as excinfo:
185199
await call_tool("shell_execute", {"command": ["sleep", "2"], "timeout": 1})
186-
assert "Command timed out after 1 seconds" in str(excinfo.value)
200+
assert "Command execution timed out" in str(excinfo.value)
187201

188202

189203
@pytest.mark.asyncio

tests/test_server_validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from mcp_shell_server.server import ExecuteToolHandler
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_server_input_validation():
8+
"""Test input validation in execute tool"""
9+
handler = ExecuteToolHandler()
10+
11+
# Test command must be an array
12+
with pytest.raises(ValueError, match="'command' must be an array"):
13+
await handler.run_tool({"command": "not an array", "directory": "/"})
14+
15+
# Test directory is required
16+
with pytest.raises(ValueError, match="Directory is required"):
17+
await handler.run_tool({"command": ["echo", "test"], "directory": ""})
18+
19+
# Test command without arguments
20+
with pytest.raises(ValueError, match="No command provided"):
21+
await handler.run_tool({"directory": "/"})

0 commit comments

Comments
 (0)