Skip to content

Commit 51d5e1f

Browse files
committed
Multiple Fixes
1 parent 8c2e1a6 commit 51d5e1f

File tree

6 files changed

+1666
-57
lines changed

6 files changed

+1666
-57
lines changed

api/middleware/security.py

Lines changed: 256 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
"""
22
Security middleware for API protection
33
"""
4-
from typing import Callable
4+
import time
5+
import hashlib
6+
import hmac
7+
import json
8+
from typing import Callable, Dict, Set, Optional
59
from starlette.middleware.base import BaseHTTPMiddleware
610
from starlette.requests import Request
7-
from starlette.responses import Response
11+
from starlette.responses import Response, JSONResponse
812
from starlette.types import ASGIApp
13+
import structlog
14+
15+
logger = structlog.get_logger()
916

1017

1118
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
@@ -66,10 +73,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
6673
return response
6774

6875

76+
class APIKeyQuota:
    """API Key quota configuration.

    A plain value holder: every argument is stored unchanged on the
    instance and no validation is performed.
    """

    def __init__(self, calls_per_hour: int = 1000, calls_per_day: int = 10000,
                 max_concurrent_jobs: int = 5, max_file_size_mb: int = 1000):
        """Store the quota limits.

        Args:
            calls_per_hour: Request ceiling for one hourly window.
            calls_per_day: Request ceiling for one daily window.
            max_concurrent_jobs: Concurrent-job ceiling.
            max_file_size_mb: Upload size ceiling, in megabytes.
        """
        # NOTE(review): only the two call ceilings are read by the rate
        # limiter in this file; the job and file-size limits are presumably
        # enforced elsewhere -- confirm.
        self.calls_per_hour = calls_per_hour
        self.calls_per_day = calls_per_day
        self.max_concurrent_jobs = max_concurrent_jobs
        self.max_file_size_mb = max_file_size_mb

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"calls_per_hour={self.calls_per_hour!r}, "
            f"calls_per_day={self.calls_per_day!r}, "
            f"max_concurrent_jobs={self.max_concurrent_jobs!r}, "
            f"max_file_size_mb={self.max_file_size_mb!r})"
        )
84+
85+
6986
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Enhanced rate limiting middleware with API key quotas.

    Requests are counted per client (IP address, plus the API key when one
    is sent) in fixed hourly and daily windows. Counters are kept in Redis
    when a client is supplied; when none is supplied, or a Redis call
    raises, an in-process dictionary that enforces only the hourly quota is
    used instead.
    """

    # Key prefix -> tier name, checked in order by _get_client_quota().
    _TIER_PREFIXES = (
        ('ent_', 'enterprise'),
        ('prem_', 'premium'),
        ('basic_', 'basic'),
    )

    def __init__(
        self,
        app: ASGIApp,
        calls: int = 1000,
        period: int = 3600,  # 1 hour
        enabled: bool = True,
        redis_client=None,  # Redis client for distributed rate limiting
    ):
        """
        Args:
            app: The wrapped ASGI application.
            calls: Stored on the instance; the quota logic in this class
                does not read it.
            period: Window length in seconds used by the in-memory fallback.
            enabled: When False every request is passed straight through.
            redis_client: Optional client whose ``get``, ``incr`` and
                ``expire`` methods are awaited.
        """
        super().__init__(app)
        self.calls = calls
        self.period = period
        self.enabled = enabled
        self.redis_client = redis_client
        self.clients = {}  # Fallback in-memory store
        # Time of the last sweep of self.clients; 0.0 forces a sweep on the
        # first fallback request.
        self._last_purge = 0.0

        # Default quotas for different API key tiers
        self.default_quotas = {
            'free': APIKeyQuota(calls_per_hour=100, calls_per_day=1000, max_concurrent_jobs=2, max_file_size_mb=100),
            'basic': APIKeyQuota(calls_per_hour=500, calls_per_day=5000, max_concurrent_jobs=5, max_file_size_mb=500),
            'premium': APIKeyQuota(calls_per_hour=2000, calls_per_day=20000, max_concurrent_jobs=10, max_file_size_mb=2000),
            'enterprise': APIKeyQuota(calls_per_hour=10000, calls_per_day=100000, max_concurrent_jobs=50, max_file_size_mb=10000),
        }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Apply enhanced rate limiting with API key quotas.

        Returns a 429 JSON response when the hourly or daily quota is
        exhausted; otherwise the downstream response with ``X-RateLimit-*``
        headers added.
        """
        if not self.enabled:
            return await call_next(request)

        # Client identifier: IP, plus the API key if one was sent.
        client_ip = self._get_client_ip(request)
        api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
        client_id = f"{client_ip}:{api_key}" if api_key else client_ip

        quota = await self._get_client_quota(api_key)
        current_time = time.time()

        if not self.redis_client:
            return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request)

        # Fixed windows: the bucket number is part of the key, so a new
        # hour/day automatically starts from a fresh counter.
        hour_key = f"{client_id}:hour:{int(current_time // 3600)}"
        day_key = f"{client_id}:day:{int(current_time // 86400)}"

        try:
            hourly_count = int((await self.redis_client.get(hour_key)) or 0)
            daily_count = int((await self.redis_client.get(day_key)) or 0)
            over_hour = hourly_count >= quota.calls_per_hour
            over_day = daily_count >= quota.calls_per_day

            if not (over_hour or over_day):
                # INCR is atomic and returns the new value. Re-checking that
                # value closes the gap between the reads above and the
                # increment, where concurrent requests could all slip past.
                hourly_count = int(await self.redis_client.incr(hour_key))
                await self.redis_client.expire(hour_key, 3600)  # 1 hour TTL
                daily_count = int(await self.redis_client.incr(day_key))
                await self.redis_client.expire(day_key, 86400)  # 1 day TTL
                over_hour = hourly_count > quota.calls_per_hour
                over_day = daily_count > quota.calls_per_day
        except Exception as e:
            # Deliberately broad: any Redis failure degrades to the
            # in-memory limiter instead of failing the request.
            logger.warning("Redis rate limiting failed, using fallback", error=str(e))
            return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request)

        if over_hour:
            return self._rate_limit_response(quota.calls_per_hour, "hour", hourly_count)
        if over_day:
            return self._rate_limit_response(quota.calls_per_day, "day", daily_count)

        response = await call_next(request)
        # Both counts already include this request.
        response.headers["X-RateLimit-Limit-Hour"] = str(quota.calls_per_hour)
        response.headers["X-RateLimit-Limit-Day"] = str(quota.calls_per_day)
        response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, quota.calls_per_hour - hourly_count))
        response.headers["X-RateLimit-Remaining-Day"] = str(max(0, quota.calls_per_day - daily_count))

        return response

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the first X-Forwarded-For entry, else the peer address.

        Falls back to ``'unknown'`` when the request carries no client
        information, so such requests share one bucket instead of raising.
        """
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it; confirm the deployment guarantees that,
        # otherwise a caller can dodge the limit by varying the header.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        return request.client.host if request.client else "unknown"

    async def _get_client_quota(self, api_key: Optional[str] = None) -> APIKeyQuota:
        """Get quota configuration for client based on API key tier.

        No key selects the ``free`` tier. A key is matched on its prefix
        (``ent_``, ``prem_``, ``basic_``); any other key gets ``basic``.
        """
        if not api_key:
            return self.default_quotas['free']

        # In production, look up API key tier from database.
        # NOTE(review): the key is not validated here, so any non-empty value
        # receives at least the basic quota and its own counters.
        for prefix, tier in self._TIER_PREFIXES:
            if api_key.startswith(prefix):
                return self.default_quotas[tier]
        return self.default_quotas['basic']  # Default for unknown keys

    def _rate_limit_response(self, limit: int, period: str, current_count: int) -> JSONResponse:
        """Create rate limit exceeded response.

        Args:
            limit: The quota that was exceeded.
            period: ``"hour"`` or ``"day"``; used in the message, the header
                names and to choose the Retry-After value.
            current_count: Counter value reported as ``current_usage``.
        """
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": "RATE_LIMIT_EXCEEDED",
                    "message": f"Rate limit exceeded. Maximum {limit} requests per {period}.",
                    "type": "RateLimitError",
                    "limit": limit,
                    "period": period,
                    "current_usage": current_count
                }
            },
            headers={
                f"X-RateLimit-Limit-{period.title()}": str(limit),
                f"X-RateLimit-Remaining-{period.title()}": "0",
                # Full window length, not the time left in the current window.
                "Retry-After": "3600" if period == "hour" else "86400"
            }
        )

    def _purge_expired_clients(self, current_time: float) -> None:
        """Drop in-memory entries whose window has ended, at most once per period."""
        if current_time - self._last_purge < self.period:
            return
        self.clients = {
            cid: data for cid, data in self.clients.items()
            if current_time - data["window_start"] < self.period
        }
        self._last_purge = current_time

    async def _fallback_rate_limiting(self, client_id: str, quota: APIKeyQuota,
                                      current_time: float, call_next: Callable, request: Request):
        """Fallback in-memory rate limiting when Redis is unavailable.

        Only the hourly quota is enforced, over a window of ``self.period``
        seconds that starts at the client's first request. State is local to
        this process.
        """
        # Sweeping on every request would rebuild the whole dict each time;
        # the per-client window check below keeps results correct in between.
        self._purge_expired_clients(current_time)

        client_data = self.clients.get(client_id)
        if client_data is not None and current_time - client_data["window_start"] < self.period:
            if client_data["requests"] >= quota.calls_per_hour:
                return self._rate_limit_response(quota.calls_per_hour, "hour", client_data["requests"])
            client_data["requests"] += 1
        else:
            # New client, or the previous window has ended: start a new one.
            client_data = {
                "requests": 1,
                "window_start": current_time
            }
            self.clients[client_id] = client_data

        # Snapshot before awaiting: other requests may bump the counter
        # while the downstream handler runs.
        used = client_data["requests"]

        response = await call_next(request)
        response.headers["X-RateLimit-Limit-Hour"] = str(quota.calls_per_hour)
        response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, quota.calls_per_hour - used))
        return response
245+
246+
247+
class InputSanitizationMiddleware(BaseHTTPMiddleware):
    """Middleware for sanitizing and validating input data.

    Rejects requests whose declared Content-Length exceeds the configured
    maximum (413) and POST/PUT/PATCH requests whose Content-Type is not one
    of the accepted media types (415).
    """

    # Methods whose Content-Type is validated.
    _BODY_METHODS = ('POST', 'PUT', 'PATCH')
    # Accepted media-type prefixes, lower-case for case-insensitive matching.
    _ALLOWED_MEDIA_TYPES = (
        'application/json',
        'multipart/form-data',
        'application/x-www-form-urlencoded',
    )

    def __init__(self, app: ASGIApp, max_body_size: int = 100 * 1024 * 1024):  # 100MB default
        """
        Args:
            app: The wrapped ASGI application.
            max_body_size: Largest accepted Content-Length, in bytes.
        """
        super().__init__(app)
        self.max_body_size = max_body_size

    @staticmethod
    def _error_response(status_code: int, code: str, message: str) -> JSONResponse:
        """Build the JSON error envelope shared by every rejection."""
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "code": code,
                    "message": message,
                    "type": "RequestError"
                }
            }
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Sanitize request data.

        Returns a 400/413/415 JSON error when validation fails, otherwise
        the downstream response. Exceptions raised downstream propagate
        unchanged rather than being reported to the client as a 400.
        """
        # Check content length.
        # NOTE: only the declared header is checked; a body sent without a
        # Content-Length header is not measured here.
        content_length = request.headers.get('content-length')
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError as e:
                logger.error("Input sanitization failed", error=str(e))
                return self._error_response(400, "BAD_REQUEST", "Invalid request format")

            if declared_size > self.max_body_size:
                return self._error_response(
                    413,
                    "PAYLOAD_TOO_LARGE",
                    f"Request body too large. Maximum size: {self.max_body_size} bytes",
                )

        # Validate Content-Type for POST/PUT/PATCH requests. Media types are
        # case-insensitive, so compare in lower case.
        if request.method in self._BODY_METHODS:
            content_type = request.headers.get('content-type', '').strip().lower()
            if not content_type.startswith(self._ALLOWED_MEDIA_TYPES):
                return self._error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Unsupported media type")

        # Kept outside any try block so application errors reach the app's
        # own exception handling.
        return await call_next(request)
300+
301+
302+
class SecurityAuditMiddleware(BaseHTTPMiddleware):
    """Middleware for security auditing and monitoring.

    Logs a warning for requests whose URL matches a suspicious pattern, for
    requests slower than ``SLOW_REQUEST_SECONDS``, and for 401 responses.
    It only logs; no request is ever blocked.
    """

    # Processing time above which a request is logged as slow.
    SLOW_REQUEST_SECONDS = 30

    def __init__(self, app: ASGIApp, log_suspicious_activity: bool = True):
        """
        Args:
            app: The wrapped ASGI application.
            log_suspicious_activity: When False the pattern scan is skipped.
        """
        super().__init__(app)
        self.log_suspicious_activity = log_suspicious_activity
        # Regular expressions, matched case-insensitively.
        self.suspicious_patterns = [
            r'\.\./',  # Directory traversal
            r'<script',  # XSS attempts
            r'union\s+select',  # SQL injection
            r'javascript:',  # XSS
            r'eval\s*\(',  # Code injection
            r'/etc/passwd',  # File access attempts
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Monitor and audit security events."""
        # Monotonic clock: the measured duration is unaffected by
        # wall-clock adjustments.
        start_time = time.monotonic()

        # Check for suspicious patterns
        if self.log_suspicious_activity:
            self._check_for_suspicious_activity(request)

        response = await call_next(request)

        # Log security events
        processing_time = time.monotonic() - start_time

        if processing_time > self.SLOW_REQUEST_SECONDS:  # Slow request detection
            logger.warning(
                "Slow request detected",
                path=request.url.path,
                processing_time=processing_time,
                client_ip=self._get_client_ip(request)
            )

        if response.status_code == 401:
            logger.warning(
                "Authentication failed",
                path=request.url.path,
                client_ip=self._get_client_ip(request)
            )

        return response

    def _check_for_suspicious_activity(self, request: Request):
        """Check for suspicious patterns in the request.

        Scans the URL path and the percent-decoded query string against
        every entry of ``self.suspicious_patterns``, logging one warning per
        match.
        """
        import re
        from urllib.parse import unquote_plus

        path = request.url.path
        # Decode so that encoded payloads such as "union+select" or
        # "%3Cscript" are matched too.
        query = unquote_plus(request.url.query or "")

        for pattern in self.suspicious_patterns:
            if re.search(pattern, path, re.IGNORECASE):
                logger.warning(
                    "Suspicious pattern in URL",
                    pattern=pattern,
                    url=path,
                    client_ip=self._get_client_ip(request)
                )
            if query and re.search(pattern, query, re.IGNORECASE):
                # The query string itself is not logged: it may carry
                # credentials such as an api_key parameter.
                logger.warning(
                    "Suspicious pattern in query string",
                    pattern=pattern,
                    url=path,
                    client_ip=self._get_client_ip(request)
                )

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address.

        Returns the first X-Forwarded-For entry when the header is present,
        otherwise the peer address, or ``'unknown'`` when the request has no
        client information.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return request.client.host if request.client else 'unknown'

0 commit comments

Comments
 (0)