Skip to content

Commit 114b928

Browse files
authored
Merge pull request #17 from garland3/security-fixes
Fix critical security vulnerabilities (P0) - GitHub Issue #16
2 parents 68c48ed + 6ae80f8 commit 114b928

File tree

5 files changed

+509
-25
lines changed

5 files changed

+509
-25
lines changed

src/talkpipe/app/chatterlang_serve.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ async def favicon():
179179
# Configure middleware
180180
self._setup_middleware()
181181

182+
# Add security headers middleware
183+
self._setup_security_headers()
184+
182185
# Configure routes
183186
self._setup_routes()
184187

@@ -221,13 +224,15 @@ def get_or_create_session(self, request: Request, response: Response) -> UserSes
221224
)
222225
self.sessions[session_id] = session
223226

224-
# Set session cookie (expires in 24 hours)
227+
# Set session cookie (expires in 24 hours) with security attributes
225228
response.set_cookie(
226229
key="talkpipe_session_id",
227230
value=session_id,
228231
max_age=86400, # 24 hours
229-
httponly=True,
230-
samesite="lax"
232+
httponly=True, # Prevent JavaScript access
233+
samesite="lax", # CSRF protection
234+
secure=False, # Set to True in production with HTTPS
235+
path="/" # Restrict cookie path
231236
)
232237

233238
logger.info(f"Created new session: {session_id}")
@@ -268,15 +273,61 @@ def cleanup_worker():
268273
logger.info("Started session cleanup background task")
269274

270275
def _setup_middleware(self):
    """Install the CORS middleware with an explicit origin allowlist.

    The allowlist is made of the loopback origins on ports 3000 and 8000,
    the loopback origins on this server's own port (``self.port``), and any
    extra origins listed, comma-separated, in the
    ``TALKPIPE_ALLOWED_ORIGINS`` environment variable. A wildcard origin is
    deliberately never used because credentials are allowed.
    """
    import os

    loopback_hosts = ("localhost", "127.0.0.1")

    # Fixed development origins first, then the origins for the configured port.
    origins = [
        f"http://{host}:{dev_port}"
        for host in loopback_hosts
        for dev_port in (3000, 8000)
    ]
    origins += [f"http://{host}:{self.port}" for host in loopback_hosts]

    # Optional deployment-specific origins; blank entries are ignored.
    for raw_origin in os.getenv('TALKPIPE_ALLOWED_ORIGINS', '').split(','):
        cleaned = raw_origin.strip()
        if cleaned:
            origins.append(cleaned)

    cors_options = {
        "allow_origins": origins,  # Specific origins only - never "*"
        "allow_credentials": True,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization", "X-API-Key"],
        "expose_headers": ["Content-Type"],
        "max_age": 86400,  # Cache preflight requests for 24 hours
    }
    self.app.add_middleware(CORSMiddleware, **cors_options)
279301

302+
def _setup_security_headers(self):
    """Register an HTTP middleware that sets security headers on every response.

    The header values are constant, so they are built once here rather than
    on every request. The middleware assigns each header on the response
    returned by the downstream handler, overwriting a value of the same
    name if the handler already set one.
    """
    content_security_policy = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data:; "
        "connect-src 'self'; "
        "font-src 'self'; "
        "object-src 'none'; "
        "media-src 'self'; "
        "child-src 'none';"
    )

    # 'speaker' is not a recognised Permissions-Policy feature; browsers
    # report a header error for it, so it is not listed here.
    permissions_policy = (
        "camera=(), microphone=(), geolocation=(), payment=(), "
        "usb=(), magnetometer=(), gyroscope=()"
    )

    security_headers = {
        "X-Frame-Options": "DENY",
        "X-Content-Type-Options": "nosniff",
        # NOTE(review): X-XSS-Protection is deprecated and ignored by current
        # browsers; current guidance is to send "0" or omit it. Value kept
        # unchanged here - confirm before altering.
        "X-XSS-Protection": "1; mode=block",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Content-Security-Policy": content_security_policy,
        "Permissions-Policy": permissions_policy,
    }

    @self.app.middleware("http")
    async def add_security_headers(request, call_next):
        response = await call_next(request)
        for header_name, header_value in security_headers.items():
            response.headers[header_name] = header_value
        return response
330+
280331
def _setup_routes(self):
281332
"""Configure all API routes"""
282333

@@ -326,12 +377,21 @@ async def get_form_config():
326377
return self.form_config.model_dump()
327378

328379
@self.app.get("/output-stream")
async def output_stream(
    request: Request,
    response: Response,
    api_key: str = Depends(self._verify_api_key)
):
    """Server-Sent Events endpoint for streaming output.

    Resolves (or creates) the caller's session and streams that session's
    output as ``text/event-stream``. Access is gated by the
    ``_verify_api_key`` dependency.
    """
    session = self.get_or_create_session(request, response)
    stream = StreamingResponse(
        self._generate_output_stream(session),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Content-Type-Options": "nosniff"
        }
    )

    # get_or_create_session() sets the session cookie on the injected
    # `response`. FastAPI only merges that object's headers into the final
    # response when the endpoint returns a non-Response value; because a
    # StreamingResponse is returned directly, the cookie must be copied
    # across explicitly or a newly created session is never sent to the
    # client.
    for cookie_header in response.headers.getlist("set-cookie"):
        stream.headers.append("set-cookie", cookie_header)

    return stream
336396

337397
async def _verify_api_key(self, x_api_key: Optional[str] = Header(None)):

src/talkpipe/util/data_manipulation.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,23 @@ def compileLambda(expression: str, fail_on_error: bool = True):
284284
Returns:
285285
A callable function that takes a single 'item' parameter and returns the evaluated expression result
286286
"""
287+
# Security check: block dangerous patterns in expressions
288+
dangerous_patterns = [
289+
'__import__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
290+
'input', 'raw_input', 'reload', 'vars', 'locals', 'globals',
291+
'dir', 'hasattr', 'getattr', 'setattr', 'delattr', 'classmethod',
292+
'staticmethod', 'super', 'property', '__', '.mro', '.subclasses'
293+
]
294+
295+
expression_lower = expression.lower()
296+
for pattern in dangerous_patterns:
297+
if pattern in expression_lower:
298+
raise ValueError(f"Security violation: Expression contains prohibited pattern '{pattern}'")
299+
300+
# Additional security: check for attribute access to dangerous methods
301+
if '.__' in expression or 'getitem' in expression_lower or 'setitem' in expression_lower:
302+
raise ValueError("Security violation: Expression contains prohibited attribute access patterns")
303+
287304
# Set of safe built-ins that can be used in expressions
288305
_SAFE_BUILTINS = {
289306
'abs': abs, 'all': all, 'any': any, 'bool': bool, 'dict': dict,
@@ -315,11 +332,18 @@ def lambda_function(item: Any) -> Any:
315332

316333
# If item is a dictionary, add its keys as variables for convenience
317334
if isinstance(item, dict):
318-
locals_dict.update(item)
335+
# Filter dictionary keys to prevent injection of dangerous names
336+
safe_keys = {k: v for k, v in item.items()
337+
if isinstance(k, str) and not k.startswith('_') and k not in dangerous_patterns}
338+
locals_dict.update(safe_keys)
339+
340+
# Create a completely restricted environment with no access to dangerous globals
341+
restricted_globals = {'__builtins__': {}}
342+
restricted_globals.update(SAFE_BUILTINS)
319343

320-
# Evaluate the expression in a restricted environment
344+
# Evaluate the expression in a heavily restricted environment
321345
try:
322-
result = eval(compiled_code, dict(SAFE_BUILTINS), locals_dict)
346+
result = eval(compiled_code, restricted_globals, locals_dict)
323347
return result
324348
except Exception as e:
325349
error_msg = f"Error evaluating expression '{expression}' on item {item}: {e}"

src/talkpipe/util/os.py

Lines changed: 128 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,145 @@
11
import logging
2+
import subprocess
3+
import shlex
24

35
logger = logging.getLogger(__name__)
46

57

6-
import subprocess
8+
class SecurityError(Exception):
    """Signals that a security violation was detected."""
711

812

913
def run_command(command: str):
    """
    Run an external command without a shell and yield each line of its stdout.

    This is a generator: nothing happens (including validation) until the
    caller starts iterating. The command string is validated, split into an
    argument list with ``shlex``, checked against the command allowlist and
    then executed with ``shell=False``.

    stderr is captured to a temporary file rather than a pipe. A pipe that
    nobody reads while stdout is being streamed can fill up and block the
    child process forever.

    Args:
        command: The command to run as a string.

    Yields:
        Each line from the command's stdout, with trailing whitespace removed.

    Raises:
        SecurityError: If the command fails validation, or if it cannot be
            parsed or started (the underlying error is chained as
            ``__cause__``).
        subprocess.CalledProcessError: If the command exits with a non-zero
            status; the captured stderr text is attached as ``stderr``.
    """
    import tempfile

    # Security validation on the raw string.
    _validate_command_security(command)

    logger.debug(f"Executing validated command: {command}")

    process = None
    try:
        # Split command safely using shlex; the list form is what makes
        # shell=False possible.
        command_parts = shlex.split(command)
        if not command_parts:
            raise ValueError("Empty command provided")

        # Check that the executable itself is on the allowlist.
        _validate_base_command(command_parts[0])

        with tempfile.TemporaryFile() as stderr_file:
            process = subprocess.Popen(
                command_parts,
                stdout=subprocess.PIPE,
                stderr=stderr_file,
                text=True,
                shell=False  # Critical: never use shell=True
            )

            for line in process.stdout:
                logger.debug(f"Command output: {line.rstrip()}")
                yield line.rstrip()  # Remove trailing newline

            process.wait()  # Wait for the command to complete

            if process.returncode != 0:
                # Read back whatever the command wrote to stderr for reporting.
                stderr_file.seek(0)
                stderr_output = stderr_file.read().decode(errors="replace")
                logger.error(f"Command failed with return code {process.returncode}: {stderr_output}")
                raise subprocess.CalledProcessError(
                    process.returncode, command, stderr=stderr_output
                )

        logger.debug("Command completed successfully")

    except (subprocess.CalledProcessError, SecurityError):
        # Already the documented exception types: propagate unchanged so a
        # validation failure is not reported as an execution failure.
        raise
    except Exception as e:
        logger.error(f"Error executing command '{command}': {e}")
        raise SecurityError(f"Command execution failed: {e}") from e
    finally:
        # Runs on normal completion, on error, and when the consumer stops
        # iterating early; make sure no child process or pipe is left behind.
        if process is not None:
            if process.poll() is None:
                process.kill()
                process.wait()
            if process.stdout is not None:
                process.stdout.close()
75+
76+
77+
def _validate_command_security(command: str):
    """
    Validate that the command does not contain dangerous patterns.

    Three groups of checks are applied:

    1. Shell metacharacters and line breaks anywhere in the raw string.
    2. Path traversal markers (``..`` and ``~/``).
    3. References to a small set of sensitive paths (case-insensitive).

    The path checks (2 and 3) are applied to the raw string and also to the
    arguments as ``shlex`` parses them, because ``run_command`` executes the
    parsed arguments. Checking only the raw text would let quoting such as
    ``/etc/pass''wd`` or redundant separators such as ``/etc//passwd`` slip
    through.

    Args:
        command: The command string to validate.

    Raises:
        SecurityError: If dangerous patterns are detected.
    """
    import posixpath
    import shlex

    # Check for dangerous shell metacharacters and patterns
    dangerous_patterns = [
        ';',    # Command separator
        '&&',   # Command chaining
        '||',   # Command chaining
        '|',    # Pipe (could be used maliciously)
        '$(',   # Command substitution
        '`',    # Command substitution (backticks)
        '>',    # Redirection
        '<',    # Redirection
        '&',    # Background execution
        '\n',   # Newline injection
        '\r',   # Carriage return injection
    ]

    for pattern in dangerous_patterns:
        if pattern in command:
            raise SecurityError(f"Security violation: Command contains dangerous pattern '{pattern}'")

    # Arguments as they will actually reach the process. If the string cannot
    # be parsed (e.g. unbalanced quotes) only the raw-string checks apply.
    try:
        parsed_args = shlex.split(command)
    except ValueError:
        parsed_args = []

    texts_to_check = [command, *parsed_args]

    # Check for path traversal attempts
    if any('..' in text or '~/' in text for text in texts_to_check):
        raise SecurityError("Security violation: Command contains path traversal patterns")

    # Check for attempts to access sensitive files. Normalised forms of each
    # argument are included so '//' and '/./' cannot disguise a path.
    sensitive_paths = ['/etc/passwd', '/etc/shadow', '/root/', '~root']
    lowered_texts = [text.lower() for text in texts_to_check]
    lowered_texts.extend(
        posixpath.normpath(arg).lower() for arg in parsed_args if arg
    )
    for path in sensitive_paths:
        if any(path in text for text in lowered_texts):
            raise SecurityError(f"Security violation: Command attempts to access sensitive path '{path}'")
116+
117+
118+
def _validate_base_command(base_command: str):
    """
    Check the executable name against the allowlist of permitted commands.

    Only the final path component of ``base_command`` is compared, so
    ``/usr/bin/ls`` and ``ls`` are treated the same.

    Args:
        base_command: The first token of the command line (may include a path).

    Raises:
        SecurityError: If the command name is not in the allowlist.
    """
    # Allowlist of permitted commands (can be extended as needed).
    # NOTE(review): entries such as python, awk, find, ssh and curl can
    # themselves run or fetch arbitrary code - confirm this list is intended.
    allowed_commands = frozenset((
        'ls', 'cat', 'echo', 'pwd', 'head', 'tail', 'grep', 'find', 'wc',
        'sort', 'uniq', 'cut', 'awk', 'sed', 'tr', 'date', 'whoami',
        'id', 'uptime', 'df', 'du', 'ps', 'top', 'free', 'mount',
        'python', 'python3', 'pip', 'git', 'curl', 'wget', 'ssh',
        'rsync', 'tar', 'gzip', 'gunzip', 'zip', 'unzip',
    ))

    # Keep only the text after the last '/'.
    # NOTE(review): the directory part is discarded, so any executable whose
    # file name matches an allowed command passes regardless of location.
    command_name = base_command.rpartition('/')[2]

    if command_name in allowed_commands:
        logger.debug(f"Base command '{command_name}' validated successfully")
        return

    # Log the attempt for security monitoring before rejecting it.
    logger.warning(f"Attempted execution of non-whitelisted command: {base_command}")
    raise SecurityError(f"Security violation: Command '{command_name}' is not in the allowed list")

tests/talkpipe/util/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def test_run_command_error_handling():
5555
# Use a command that's extremely unlikely to exist
5656
command = "this_command_definitely_does_not_exist_12345"
5757

58-
# Should raise CalledProcessError or print error message
59-
with pytest.raises(subprocess.CalledProcessError):
58+
# Should raise CalledProcessError or SecurityError (for non-whitelisted commands)
59+
with pytest.raises((subprocess.CalledProcessError, talkpipe.util.os.SecurityError)):
6060
list(talkpipe.util.os.run_command(command))
6161

6262
def test_run_command_with_arguments():

0 commit comments

Comments
 (0)