Skip to content

Commit 2def658

Browse files
committed
refactor: extract command validation logic to CommandValidator class
- Create CommandValidator class to handle command validation - Move validation methods from ShellExecutor to CommandValidator - Update tests and maintain 91% code coverage
1 parent d383539 commit 2def658

File tree

6 files changed

+203
-63
lines changed

6 files changed

+203
-63
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Provides validation for shell commands and ensures they are allowed to be executed.
3+
"""
4+
5+
import os
6+
from typing import Dict, List
7+
8+
9+
class CommandValidator:
10+
"""
11+
Validates shell commands against a whitelist and checks for unsafe operators.
12+
"""
13+
14+
def __init__(self):
15+
"""
16+
Initialize the validator.
17+
"""
18+
pass
19+
20+
def _get_allowed_commands(self) -> set[str]:
21+
"""Get the set of allowed commands from environment variables"""
22+
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
23+
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
24+
commands = allow_commands + "," + allowed_commands
25+
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
26+
27+
def get_allowed_commands(self) -> list[str]:
28+
"""Get the list of allowed commands from environment variables"""
29+
return list(self._get_allowed_commands())
30+
31+
def is_command_allowed(self, command: str) -> bool:
32+
"""Check if a command is in the allowed list"""
33+
cmd = command.strip()
34+
return cmd in self._get_allowed_commands()
35+
36+
def validate_no_shell_operators(self, cmd: str) -> None:
37+
"""
38+
Validate that the command does not contain shell operators.
39+
40+
Args:
41+
cmd (str): Command to validate
42+
43+
Raises:
44+
ValueError: If the command contains shell operators
45+
"""
46+
if cmd in [";", "&&", "||", "|"]:
47+
raise ValueError(f"Unexpected shell operator: {cmd}")
48+
49+
def validate_pipeline(self, commands: List[str]) -> Dict[str, str]:
50+
"""
51+
Validate pipeline command and ensure all parts are allowed.
52+
53+
Args:
54+
commands (List[str]): List of commands to validate
55+
56+
Returns:
57+
Dict[str, str]: Error message if validation fails, empty dict if success
58+
59+
Raises:
60+
ValueError: If validation fails
61+
"""
62+
current_cmd: List[str] = []
63+
64+
for token in commands:
65+
if token == "|":
66+
if not current_cmd:
67+
raise ValueError("Empty command before pipe operator")
68+
if not self.is_command_allowed(current_cmd[0]):
69+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
70+
current_cmd = []
71+
elif token in [";", "&&", "||"]:
72+
raise ValueError(f"Unexpected shell operator in pipeline: {token}")
73+
else:
74+
current_cmd.append(token)
75+
76+
if current_cmd:
77+
if not self.is_command_allowed(current_cmd[0]):
78+
raise ValueError(f"Command not allowed: {current_cmd[0]}")
79+
80+
return {}
81+
82+
def validate_command(self, command: List[str]) -> None:
83+
"""
84+
Validate if the command is allowed to be executed.
85+
86+
Args:
87+
command (List[str]): Command and its arguments
88+
89+
Raises:
90+
ValueError: If the command is empty, not allowed, or contains invalid shell operators
91+
"""
92+
if not command:
93+
raise ValueError("Empty command")
94+
95+
allowed_commands = self._get_allowed_commands()
96+
if not allowed_commands:
97+
raise ValueError(
98+
"No commands are allowed. Please set ALLOW_COMMANDS environment variable."
99+
)
100+
101+
# Clean and check the first command
102+
cleaned_cmd = command[0].strip()
103+
if cleaned_cmd not in allowed_commands:
104+
raise ValueError(f"Command not allowed: {cleaned_cmd}")

src/mcp_shell_server/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self):
2828

2929
def get_allowed_commands(self) -> list[str]:
3030
"""Get the allowed commands"""
31-
return self.executor.get_allowed_commands()
31+
return self.executor.validator.get_allowed_commands()
3232

3333
def get_tool_description(self) -> Tool:
3434
"""Get the tool description for the execute command"""

src/mcp_shell_server/shell_executor.py

Lines changed: 19 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import time
66
from typing import IO, Any, Dict, List, Optional, Tuple, Union
77

8+
from mcp_shell_server.command_validator import CommandValidator
9+
810

911
class ShellExecutor:
1012
"""
@@ -13,25 +15,9 @@ class ShellExecutor:
1315

1416
def __init__(self):
1517
"""
16-
Initialize the executor.
18+
Initialize the executor with a command validator.
1719
"""
18-
pass # pragma: no cover
19-
20-
def _get_allowed_commands(self) -> set[str]:
21-
"""Get the set of allowed commands from environment variables"""
22-
allow_commands = os.environ.get("ALLOW_COMMANDS", "")
23-
allowed_commands = os.environ.get("ALLOWED_COMMANDS", "")
24-
commands = allow_commands + "," + allowed_commands
25-
return {cmd.strip() for cmd in commands.split(",") if cmd.strip()}
26-
27-
def get_allowed_commands(self) -> list[str]:
28-
"""Get the list of allowed commands from environment variables"""
29-
return list(self._get_allowed_commands())
30-
31-
def is_command_allowed(self, command: str) -> bool:
32-
"""Check if a command is in the allowed list"""
33-
cmd = command.strip()
34-
return cmd in self._get_allowed_commands()
20+
self.validator = CommandValidator()
3521

3622
def _validate_redirection_syntax(self, command: List[str]) -> None:
3723
"""
@@ -223,16 +209,7 @@ def _validate_command(self, command: List[str]) -> None:
223209
if not command:
224210
raise ValueError("Empty command")
225211

226-
allowed_commands = self._get_allowed_commands()
227-
if not allowed_commands:
228-
raise ValueError(
229-
"No commands are allowed. Please set ALLOW_COMMANDS environment variable."
230-
)
231-
232-
# Clean and check the first command
233-
cleaned_cmd = command[0].strip()
234-
if cleaned_cmd not in allowed_commands:
235-
raise ValueError(f"Command not allowed: {cleaned_cmd}")
212+
self.validator.validate_command(command)
236213

237214
def _validate_directory(self, directory: Optional[str]) -> None:
238215
"""
@@ -259,34 +236,15 @@ def _validate_directory(self, directory: Optional[str]) -> None:
259236

260237
def _validate_no_shell_operators(self, cmd: str) -> None:
261238
"""Validate that the command does not contain shell operators"""
262-
if cmd in [";" "&&", "||", "|"]:
263-
raise ValueError(f"Unexpected shell operator: {cmd}")
239+
self.validator.validate_no_shell_operators(cmd)
264240

265241
def _validate_pipeline(self, commands: List[str]) -> Dict[str, str]:
266242
"""Validate pipeline command and ensure all parts are allowed
267243
268244
Returns:
269245
Dict[str, str]: Error message if validation fails, empty dict if success
270246
"""
271-
current_cmd: List[str] = []
272-
273-
for token in commands:
274-
if token == "|":
275-
if not current_cmd:
276-
raise ValueError("Empty command before pipe operator")
277-
if not self.is_command_allowed(current_cmd[0]):
278-
raise ValueError(f"Command not allowed: {current_cmd[0]}")
279-
current_cmd = []
280-
elif token in [";", "&&", "||"]:
281-
raise ValueError(f"Unexpected shell operator in pipeline: {token}")
282-
else:
283-
current_cmd.append(token)
284-
285-
if current_cmd:
286-
if not self.is_command_allowed(current_cmd[0]):
287-
raise ValueError(f"Command not allowed: {current_cmd[0]}")
288-
289-
return {}
247+
return self.validator.validate_pipeline(commands)
290248

291249
def _split_pipe_commands(self, command: List[str]) -> List[List[str]]:
292250
"""
@@ -425,13 +383,15 @@ async def execute(
425383
# First check for pipe operators and handle pipeline
426384
if "|" in cleaned_command:
427385
try:
428-
# Validate pipeline first
429-
error = self._validate_pipeline(cleaned_command)
430-
if error:
386+
# Validate pipeline first using the validator
387+
try:
388+
self.validator.validate_pipeline(cleaned_command)
389+
except ValueError as e:
431390
return {
432-
**error,
391+
"error": str(e),
433392
"status": 1,
434393
"stdout": "",
394+
"stderr": str(e),
435395
"execution_time": time.time() - start_time,
436396
}
437397

@@ -464,18 +424,20 @@ async def execute(
464424

465425
# Then check for other shell operators
466426
for token in cleaned_command:
467-
if token in [";", "&&", "||"]:
427+
try:
428+
self.validator.validate_no_shell_operators(token)
429+
except ValueError as e:
468430
return {
469-
"error": f"Unexpected shell operator: {token}",
431+
"error": str(e),
470432
"status": 1,
471433
"stdout": "",
472-
"stderr": f"Unexpected shell operator: {token}",
434+
"stderr": str(e),
473435
"execution_time": time.time() - start_time,
474436
}
475437

476438
# Single command execution
477439
cmd, redirects = self._parse_command(cleaned_command)
478-
self._validate_command(cmd)
440+
self.validator.validate_command(cmd)
479441

480442
# Directory validation
481443
if directory:

tests/test_command_validator.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Test cases for the CommandValidator class."""
2+
3+
import pytest
4+
5+
from mcp_shell_server.command_validator import CommandValidator
6+
7+
8+
def clear_env(monkeypatch):
9+
monkeypatch.delenv("ALLOW_COMMANDS", raising=False)
10+
monkeypatch.delenv("ALLOWED_COMMANDS", raising=False)
11+
12+
13+
@pytest.fixture
14+
def validator():
15+
return CommandValidator()
16+
17+
18+
def test_get_allowed_commands(validator, monkeypatch):
19+
clear_env(monkeypatch)
20+
monkeypatch.setenv("ALLOW_COMMANDS", "cmd1,cmd2")
21+
monkeypatch.setenv("ALLOWED_COMMANDS", "cmd3,cmd4")
22+
assert set(validator.get_allowed_commands()) == {"cmd1", "cmd2", "cmd3", "cmd4"}
23+
24+
25+
def test_is_command_allowed(validator, monkeypatch):
26+
clear_env(monkeypatch)
27+
monkeypatch.setenv("ALLOW_COMMANDS", "allowed_cmd")
28+
assert validator.is_command_allowed("allowed_cmd")
29+
assert not validator.is_command_allowed("disallowed_cmd")
30+
31+
32+
def test_validate_no_shell_operators(validator):
33+
validator.validate_no_shell_operators("echo") # Should not raise
34+
with pytest.raises(ValueError, match="Unexpected shell operator"):
35+
validator.validate_no_shell_operators(";")
36+
with pytest.raises(ValueError, match="Unexpected shell operator"):
37+
validator.validate_no_shell_operators("&&")
38+
39+
40+
def test_validate_pipeline(validator, monkeypatch):
41+
clear_env(monkeypatch)
42+
monkeypatch.setenv("ALLOW_COMMANDS", "ls,grep")
43+
44+
# Valid pipeline
45+
validator.validate_pipeline(["ls", "|", "grep", "test"])
46+
47+
# Empty command before pipe
48+
with pytest.raises(ValueError, match="Empty command before pipe operator"):
49+
validator.validate_pipeline(["|", "grep", "test"])
50+
51+
# Command not allowed
52+
with pytest.raises(ValueError, match="Command not allowed"):
53+
validator.validate_pipeline(["invalid_cmd", "|", "grep", "test"])
54+
55+
56+
def test_validate_command(validator, monkeypatch):
57+
clear_env(monkeypatch)
58+
59+
# No allowed commands
60+
with pytest.raises(ValueError, match="No commands are allowed"):
61+
validator.validate_command(["cmd"])
62+
63+
monkeypatch.setenv("ALLOW_COMMANDS", "allowed_cmd")
64+
65+
# Empty command
66+
with pytest.raises(ValueError, match="Empty command"):
67+
validator.validate_command([])
68+
69+
# Command not allowed
70+
with pytest.raises(ValueError, match="Command not allowed"):
71+
validator.validate_command(["disallowed_cmd"])
72+
73+
# Command allowed
74+
validator.validate_command(["allowed_cmd", "-arg"]) # Should not raise

tests/test_shell_executor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ async def test_allow_commands_precedence(executor, temp_test_dir, monkeypatch):
233233
monkeypatch.setenv("ALLOW_COMMANDS", "echo,ls")
234234
monkeypatch.setenv("ALLOWED_COMMANDS", "echo,cat")
235235

236-
allowed = executor.get_allowed_commands()
236+
allowed = executor.validator.get_allowed_commands()
237237
assert set(allowed) == {"echo", "ls", "cat"}
238238

239239

@@ -494,11 +494,11 @@ def test_validate_pipeline(executor, monkeypatch):
494494
monkeypatch.setenv("ALLOWED_COMMANDS", "echo,grep,cat")
495495

496496
# Test valid pipeline
497-
executor._validate_pipeline(["echo", "hello", "|", "grep", "h"])
497+
executor.validator.validate_pipeline(["echo", "hello", "|", "grep", "h"])
498498

499499
# Test empty command before pipe
500500
with pytest.raises(ValueError) as exc:
501-
executor._validate_pipeline(["|", "grep", "test"])
501+
executor.validator.validate_pipeline(["|", "grep", "test"])
502502
assert str(exc.value) == "Empty command before pipe operator"
503503

504504
# Test disallowed commands in pipeline

tests/test_shell_executor_error_cases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ async def test_no_allowed_commands_validation(monkeypatch):
4242
ValueError,
4343
match="No commands are allowed. Please set ALLOW_COMMANDS environment variable.",
4444
):
45-
executor._validate_command(["any_command"])
45+
executor.validator.validate_command(["any_command"])
4646

4747

4848
@pytest.mark.asyncio
4949
async def test_shell_operator_validation():
5050
"""Test validation of shell operators"""
5151
executor = ShellExecutor()
5252

53-
operators = [";" "&&", "||", "|"]
53+
operators = [";", "&&", "||", "|"]
5454
for op in operators:
5555
# Test shell operator validation
5656
with pytest.raises(ValueError, match=f"Unexpected shell operator: {op}"):

0 commit comments

Comments
 (0)