diff --git a/.gitignore b/.gitignore index b057ab7..1a015f8 100644 --- a/.gitignore +++ b/.gitignore @@ -77,4 +77,5 @@ share/python-wheels/ *.egg MANIFEST -prompt.md \ No newline at end of file +prompt.md +.aider* diff --git a/README.md b/README.md index 2ebb79a..d5657da 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ code ~/Library/Application\ Support/Claude/claude_desktop_config.json "command": "uv", "args": [ "--directory", - ".", + "/path/to/your/cloned/repository", "run", "mcp-shell-server" ], @@ -109,6 +109,18 @@ ALLOWED_COMMANDS="ls ,echo, cat" # With spaces (using alias) ALLOW_COMMANDS="ls, cat , echo" # Multiple spaces ``` +### Configuring Regex Patterns + +You can allow commands using regex patterns by setting the `ALLOW_PATTERNS` environment variable. Patterns should be separated by commas. + +Example: + +```bash +ALLOW_PATTERNS="^cmd[0-9]+$,^test.*$" +``` + +This configuration allows commands like `cmd123` and `testCommand`. + ### Request Format ```python diff --git a/pyproject.toml b/pyproject.toml index 5dbf3b5..a33322c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,6 @@ markers = [ ] filterwarnings = [ "ignore::RuntimeWarning:selectors:", - "ignore::pytest.PytestUnhandledCoroutineWarning:", "ignore::pytest.PytestUnraisableExceptionWarning:", "ignore::DeprecationWarning:pytest_asyncio.plugin:", ] diff --git a/src/mcp_shell_server/command_validator.py b/src/mcp_shell_server/command_validator.py index 2ca1765..a4267f0 100644 --- a/src/mcp_shell_server/command_validator.py +++ b/src/mcp_shell_server/command_validator.py @@ -3,7 +3,8 @@ """ import os -from typing import Dict, List +import re +from typing import Dict, List, Set class CommandValidator: @@ -28,10 +29,23 @@ def get_allowed_commands(self) -> list[str]: """Get the list of allowed commands from environment variables""" return list(self._get_allowed_commands()) + def get_allowed_patterns(self) -> List[re.Pattern]: + """Get the list of allowed regex patterns from environment variables""" + allow_patterns = os.environ.get("ALLOW_PATTERNS", "") + patterns = [pattern.strip() for pattern in allow_patterns.split(",") if pattern.strip()] + return [re.compile(pattern) for pattern in patterns] + def is_command_allowed(self, command: str) -> bool: - """Check if a command is in the allowed list""" + """Check if a command is in the allowed list or matches an allowed pattern""" + cmd = command.strip() + if cmd in self._get_allowed_commands(): + return True + for pattern in self.get_allowed_patterns(): + if pattern.match(cmd): + return True + return False cmd = command.strip() - return cmd in self._get_allowed_commands() + return cmd in self.get_allowed_commands() def validate_no_shell_operators(self, cmd: str) -> None: """ @@ -92,13 +106,12 @@ def validate_command(self, command: List[str]) -> None: if not command: raise ValueError("Empty command") - allowed_commands = self._get_allowed_commands() - if not allowed_commands: + if not self._get_allowed_commands() and not self.get_allowed_patterns(): raise ValueError( "No commands are allowed. Please set ALLOW_COMMANDS environment variable." ) # Clean and check the first command cleaned_cmd = command[0].strip() - if cleaned_cmd not in allowed_commands: + if not self.is_command_allowed(cleaned_cmd): raise ValueError(f"Command not allowed: {cleaned_cmd}") diff --git a/src/mcp_shell_server/server.py b/src/mcp_shell_server/server.py index 0a57702..2da683e 100644 --- a/src/mcp_shell_server/server.py +++ b/src/mcp_shell_server/server.py @@ -30,7 +30,17 @@ def get_allowed_commands(self) -> list[str]: """Get the allowed commands""" return self.executor.validator.get_allowed_commands() + def get_allowed_patterns(self) -> list[str]: + """Get the allowed regex patterns""" + return [pattern.pattern for pattern in self.executor.validator.get_allowed_patterns()] + def get_tool_description(self) -> Tool: + """Get the tool description for the execute command""" + allowed_commands = ', '.join(self.get_allowed_commands()) + allowed_patterns = ', '.join(self.get_allowed_patterns()) + description = f"{self.description}\n" + if allowed_commands != '': description += f"Allowed commands: {allowed_commands}\n" + if allowed_patterns != '': description += f"Allowed patterns: {allowed_patterns}" """Get the tool description for the execute command""" return Tool( name=self.name, diff --git a/tests/test_command_validator.py b/tests/test_command_validator.py index d93f5c6..a926ead 100644 --- a/tests/test_command_validator.py +++ b/tests/test_command_validator.py @@ -22,7 +22,15 @@ def test_get_allowed_commands(validator, monkeypatch): assert set(validator.get_allowed_commands()) == {"cmd1", "cmd2", "cmd3", "cmd4"} -def test_is_command_allowed(validator, monkeypatch): +def test_is_command_allowed_with_patterns(validator, monkeypatch): + clear_env(monkeypatch) + monkeypatch.setenv("ALLOW_COMMANDS", "allowed_cmd") + monkeypatch.setenv("ALLOW_PATTERNS", "^cmd[0-9]+$") + + assert validator.is_command_allowed("allowed_cmd") + assert validator.is_command_allowed("cmd123") + assert not validator.is_command_allowed("disallowed_cmd") + assert not validator.is_command_allowed("cmdabc") clear_env(monkeypatch) monkeypatch.setenv("ALLOW_COMMANDS", "allowed_cmd") assert validator.is_command_allowed("allowed_cmd")