Skip to content

Commit dcb1dfb

Browse files
committed
Merge branch 'feature/improve-tests' into develop
2 parents ce27d7c + 94e5156 commit dcb1dfb

11 files changed

+396
-131
lines changed

src/mcp_shell_server/server.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,15 @@ async def run_tool(self, arguments: dict) -> Sequence[TextContent]:
8181
content: list[TextContent] = []
8282
try:
8383
# Handle execution with timeout
84-
result = await asyncio.wait_for(
85-
self.executor.execute(
86-
command, directory, stdin, None
87-
), # Pass None for timeout
88-
timeout=timeout,
89-
)
84+
try:
85+
result = await asyncio.wait_for(
86+
self.executor.execute(
87+
command, directory, stdin, None
88+
), # Pass None for timeout
89+
timeout=timeout,
90+
)
91+
except asyncio.TimeoutError as e:
92+
raise ValueError("Command execution timed out") from e
9093

9194
if result.get("error"):
9295
raise ValueError(result["error"])

src/mcp_shell_server/shell_executor.py

Lines changed: 83 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,26 @@ class ShellExecutor:
1212

1313
def __init__(self):
1414
"""
15-
Initialize the executor. The allowed commands are read from ALLOW_COMMANDS
16-
environment variable during command validation, not at initialization.
15+
Initialize the executor.
1716
"""
1817
pass
1918

19+
def _get_allowed_commands(self) -> set[str]:
20+
"""Get the set of allowed commands from environment variables"""
21+
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
22+
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
23+
commands = allow_commands + "," + allowed_commands
24+
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
25+
26+
def get_allowed_commands(self) -> list[str]:
27+
"""Get the list of allowed commands from environment variables"""
28+
return list(self._get_allowed_commands())
29+
30+
def is_command_allowed(self, command: str) -> bool:
31+
"""Check if a command is in the allowed list"""
32+
cmd = command.strip()
33+
return cmd in self._get_allowed_commands()
34+
2035
def _validate_redirection_syntax(self, command: List[str]) -> None:
2136
"""
2237
Validate the syntax of redirection operators in the command.
@@ -155,24 +170,12 @@ async def _cleanup_handles(
155170
"""
156171
for key in ["stdout", "stderr"]:
157172
handle = handles.get(key)
158-
if isinstance(handle, IO) and handle != asyncio.subprocess.PIPE:
173+
if handle and hasattr(handle, "close") and not isinstance(handle, int):
159174
try:
160175
handle.close()
161-
except IOError:
176+
except (IOError, ValueError):
162177
pass
163178

164-
def _get_allowed_commands(self) -> set:
165-
"""
166-
Get the set of allowed commands from environment variables.
167-
Checks both ALLOW_COMMANDS and ALLOWED_COMMANDS.
168-
"""
169-
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
170-
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
171-
172-
# Combine and deduplicate commands from both environment variables
173-
commands = allow_commands + "," + allowed_commands
174-
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
175-
176179
def _clean_command(self, command: List[str]) -> List[str]:
177180
"""
178181
Clean command by trimming whitespace from each part.
@@ -253,32 +256,36 @@ def _validate_directory(self, directory: Optional[str]) -> None:
253256
if not os.access(directory, os.R_OK | os.X_OK):
254257
raise ValueError(f"Directory is not accessible: {directory}")
255258

256-
def get_allowed_commands(self) -> list[str]:
257-
"""Get the allowed commands"""
258-
return list(self._get_allowed_commands())
259-
260259
def _validate_no_shell_operators(self, cmd: str) -> None:
261260
"""Validate that the command does not contain shell operators"""
262261
if cmd in [";" "&&", "||", "|"]:
263262
raise ValueError(f"Unexpected shell operator: {cmd}")
264263

265-
def _validate_pipeline(self, commands: List[str]) -> None:
266-
"""Validate pipeline command and ensure all parts are allowed"""
264+
def _validate_pipeline(self, commands: List[str]) -> Dict[str, str]:
265+
"""Validate pipeline command and ensure all parts are allowed
266+
267+
Returns:
268+
Dict[str, str]: Error message if validation fails, empty dict if success
269+
"""
267270
current_cmd: List[str] = []
268271

269272
for token in commands:
270273
if token == "|":
271274
if not current_cmd:
272275
raise ValueError("Empty command before pipe operator")
273-
self._validate_command(current_cmd)
276+
if not self.is_command_allowed(current_cmd[0]):
277+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
274278
current_cmd = []
275279
elif token in [";", "&&", "||"]:
276280
raise ValueError(f"Unexpected shell operator in pipeline: {token}")
277281
else:
278282
current_cmd.append(token)
279283

280284
if current_cmd:
281-
self._validate_command(current_cmd)
285+
if not self.is_command_allowed(current_cmd[0]):
286+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
287+
288+
return {}
282289

283290
def _split_pipe_commands(self, command: List[str]) -> List[List[str]]:
284291
"""
@@ -379,6 +386,7 @@ async def execute(
379386
timeout: Optional[int] = None,
380387
) -> Dict[str, Any]:
381388
start_time = time.time()
389+
process = None # Initialize process variable
382390

383391
try:
384392
# Validate directory if specified
@@ -393,33 +401,55 @@ async def execute(
393401
"execution_time": time.time() - start_time,
394402
}
395403

396-
# Preprocess command to handle pipe operators
404+
# Process command
397405
preprocessed_command = self._preprocess_command(command)
398406
cleaned_command = self._clean_command(preprocessed_command)
399407
if not cleaned_command:
400-
raise ValueError("Empty command")
408+
return {
409+
"error": "Empty command",
410+
"status": 1,
411+
"stdout": "",
412+
"stderr": "Empty command",
413+
"execution_time": time.time() - start_time,
414+
}
401415

402416
# First check for pipe operators and handle pipeline
403417
if "|" in cleaned_command:
404-
commands: List[List[str]] = []
405-
current_cmd: List[str] = []
406-
for token in cleaned_command:
407-
if token == "|":
408-
if current_cmd:
409-
commands.append(current_cmd)
410-
current_cmd = []
418+
try:
419+
# Validate pipeline first
420+
error = self._validate_pipeline(cleaned_command)
421+
if error:
422+
return {
423+
**error,
424+
"status": 1,
425+
"stdout": "",
426+
"execution_time": time.time() - start_time,
427+
}
428+
429+
# Split commands
430+
commands: List[List[str]] = []
431+
current_cmd: List[str] = []
432+
for token in cleaned_command:
433+
if token == "|":
434+
if current_cmd:
435+
commands.append(current_cmd)
436+
current_cmd = []
437+
else:
438+
raise ValueError("Empty command before pipe operator")
411439
else:
412-
raise ValueError("Empty command before pipe operator")
413-
else:
414-
current_cmd.append(token)
415-
if current_cmd:
416-
commands.append(current_cmd)
417-
418-
# Validate each command in pipeline
419-
for cmd in commands:
420-
self._validate_command(cmd)
440+
current_cmd.append(token)
441+
if current_cmd:
442+
commands.append(current_cmd)
421443

422-
return await self._execute_pipeline(commands, directory, timeout)
444+
return await self._execute_pipeline(commands, directory, timeout)
445+
except ValueError as e:
446+
return {
447+
"error": str(e),
448+
"status": 1,
449+
"stdout": "",
450+
"stderr": str(e),
451+
"execution_time": time.time() - start_time,
452+
}
423453

424454
# Then check for other shell operators
425455
for token in cleaned_command:
@@ -561,6 +591,11 @@ async def communicate_with_timeout():
561591
"stderr": str(e),
562592
"execution_time": time.time() - start_time,
563593
}
594+
finally:
595+
# Ensure process is terminated
596+
if process and process.returncode is None:
597+
process.kill()
598+
await process.wait()
564599

565600
async def _execute_pipeline(
566601
self,
@@ -663,5 +698,10 @@ async def _execute_pipeline(
663698
}
664699

665700
finally:
701+
# Ensure all processes are terminated
702+
for process in processes:
703+
if process.returncode is None:
704+
process.kill()
705+
await process.wait()
666706
if isinstance(last_stdout, IO):
667707
last_stdout.close()

tests/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1+
import os
2+
3+
14
# Configure pytest-asyncio
25
def pytest_configure(config):
36
"""Configure pytest-asyncio defaults"""
4-
config.option.asyncio_mode = "strict"
7+
# Enable command execution for tests
8+
os.environ["ALLOW_COMMANDS"] = "1"
9+
# Add allowed commands for tests
10+
os.environ["ALLOWED_COMMANDS"] = (
11+
"echo,sleep,cat,ls,pwd,touch,mkdir,rm,mv,cp,grep,awk,sed"
12+
)

tests/test_server.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@ def temp_test_dir():
1818
@pytest.mark.asyncio
1919
async def test_list_tools():
2020
"""Test listing of available tools"""
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_tool_execution_timeout():
25+
"""Test tool execution with timeout"""
26+
with pytest.raises(RuntimeError, match="Command execution timed out"):
27+
await call_tool(
28+
"shell_execute",
29+
{
30+
"command": ["sleep", "2"],
31+
"directory": "/tmp",
32+
"timeout": 1,
33+
},
34+
)
2135
tools = await list_tools()
2236
assert len(tools) == 1
2337
tool = tools[0]
@@ -183,7 +197,7 @@ async def test_call_tool_with_timeout(monkeypatch):
183197
monkeypatch.setenv("ALLOW_COMMANDS", "sleep")
184198
with pytest.raises(RuntimeError) as excinfo:
185199
await call_tool("shell_execute", {"command": ["sleep", "2"], "timeout": 1})
186-
assert "Command timed out after 1 seconds" in str(excinfo.value)
200+
assert "Command execution timed out" in str(excinfo.value)
187201

188202

189203
@pytest.mark.asyncio

tests/test_server_validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
from mcp_shell_server.server import ExecuteToolHandler
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_server_input_validation():
8+
"""Test input validation in execute tool"""
9+
handler = ExecuteToolHandler()
10+
11+
# Test command must be an array
12+
with pytest.raises(ValueError, match="'command' must be an array"):
13+
await handler.run_tool({"command": "not an array", "directory": "/"})
14+
15+
# Test directory is required
16+
with pytest.raises(ValueError, match="Directory is required"):
17+
await handler.run_tool({"command": ["echo", "test"], "directory": ""})
18+
19+
# Test command without arguments
20+
with pytest.raises(ValueError, match="No command provided"):
21+
await handler.run_tool({"directory": "/"})

0 commit comments

Comments
 (0)