1
1
"""
2
2
Security middleware for API protection
3
3
"""
4
- from typing import Callable
4
+ import time
5
+ import hashlib
6
+ import hmac
7
+ import json
8
+ from typing import Callable , Dict , Set , Optional
5
9
from starlette .middleware .base import BaseHTTPMiddleware
6
10
from starlette .requests import Request
7
- from starlette .responses import Response
11
+ from starlette .responses import Response , JSONResponse
8
12
from starlette .types import ASGIApp
13
+ import structlog
14
+
15
+ logger = structlog .get_logger ()
9
16
10
17
11
18
class SecurityHeadersMiddleware (BaseHTTPMiddleware ):
@@ -66,10 +73,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
66
73
return response
67
74
68
75
76
class APIKeyQuota:
    """Bundle of usage limits associated with one API key tier.

    The four values are stored exactly as passed in; no validation or
    normalisation is performed here.
    """

    def __init__(
        self,
        calls_per_hour: int = 1000,
        calls_per_day: int = 10000,
        max_concurrent_jobs: int = 5,
        max_file_size_mb: int = 1000,
    ):
        # Request-count ceilings.
        self.calls_per_hour, self.calls_per_day = calls_per_hour, calls_per_day
        # Job-count and upload-size ceilings.
        self.max_concurrent_jobs, self.max_file_size_mb = (
            max_concurrent_jobs,
            max_file_size_mb,
        )
84
+
85
+
69
86
class RateLimitMiddleware (BaseHTTPMiddleware ):
70
87
"""
71
- Simple rate limiting middleware for additional protection.
72
- Note: Primary rate limiting is handled by KrakenD API Gateway.
88
+ Enhanced rate limiting middleware with API key quotas.
73
89
"""
74
90
75
91
def __init__ (
@@ -78,61 +94,273 @@ def __init__(
78
94
calls : int = 1000 ,
79
95
period : int = 3600 , # 1 hour
80
96
enabled : bool = True ,
97
+ redis_client = None , # Redis client for distributed rate limiting
81
98
):
82
99
super ().__init__ (app )
83
100
self .calls = calls
84
101
self .period = period
85
102
self .enabled = enabled
86
- self .clients = {} # Simple in-memory store (use Redis in production)
103
+ self .redis_client = redis_client
104
+ self .clients = {} # Fallback in-memory store
105
+
106
+ # Default quotas for different API key tiers
107
+ self .default_quotas = {
108
+ 'free' : APIKeyQuota (calls_per_hour = 100 , calls_per_day = 1000 , max_concurrent_jobs = 2 , max_file_size_mb = 100 ),
109
+ 'basic' : APIKeyQuota (calls_per_hour = 500 , calls_per_day = 5000 , max_concurrent_jobs = 5 , max_file_size_mb = 500 ),
110
+ 'premium' : APIKeyQuota (calls_per_hour = 2000 , calls_per_day = 20000 , max_concurrent_jobs = 10 , max_file_size_mb = 2000 ),
111
+ 'enterprise' : APIKeyQuota (calls_per_hour = 10000 , calls_per_day = 100000 , max_concurrent_jobs = 50 , max_file_size_mb = 10000 )
112
+ }
87
113
88
114
async def dispatch(self, request: Request, call_next: Callable) -> Response:
    """Enforce hourly and daily request quotas, then forward the request.

    The client is identified by its IP address (first ``X-Forwarded-For``
    entry when that header is present and non-empty, otherwise the socket
    peer, otherwise ``"unknown"``) combined with the API key taken from the
    ``X-API-Key`` header or the ``api_key`` query parameter.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The downstream response annotated with ``X-RateLimit-*`` headers;
        the response built by ``_rate_limit_response`` when a quota is
        exhausted; or the result of ``_fallback_rate_limiting`` when no
        Redis client is configured or a Redis call raises.
    """
    if not self.enabled:
        return await call_next(request)

    # NOTE(review): X-Forwarded-For is client-supplied; it is only
    # trustworthy when a proxy in front of this service sets it.
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        client_ip = forwarded_for.split(",")[0].strip()
    elif request.client:
        client_ip = request.client.host
    else:
        # request.client may be None; mirror _get_client_ip's placeholder
        # instead of raising AttributeError.
        client_ip = "unknown"

    api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
    client_id = f"{client_ip}:{api_key}" if api_key else client_ip

    quota = await self._get_client_quota(api_key)

    # Counters are bucketed by the index of the current hour / day since
    # the epoch, so each bucket gets its own key.
    current_time = time.time()
    hour_key = f"{client_id}:hour:{int(current_time // 3600)}"
    day_key = f"{client_id}:day:{int(current_time // 86400)}"

    if not self.redis_client:
        # No distributed store configured: use the in-memory limiter.
        return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request)

    try:
        hourly_count = int(await self.redis_client.get(hour_key) or 0)
        daily_count = int(await self.redis_client.get(day_key) or 0)

        if hourly_count >= quota.calls_per_hour:
            return self._rate_limit_response(quota.calls_per_hour, "hour", hourly_count)

        if daily_count >= quota.calls_per_day:
            return self._rate_limit_response(quota.calls_per_day, "day", daily_count)

        # NOTE(review): the read above and the increments below are separate
        # round trips, so concurrent requests can slightly overshoot a quota.
        await self.redis_client.incr(hour_key)
        await self.redis_client.expire(hour_key, 3600)  # 1 hour TTL
        await self.redis_client.incr(day_key)
        await self.redis_client.expire(day_key, 86400)  # 1 day TTL
    except Exception as e:
        # Any Redis failure degrades to the in-memory limiter rather than
        # failing the request.
        logger.warning("Redis rate limiting failed, using fallback", error=str(e))
        return await self._fallback_rate_limiting(client_id, quota, current_time, call_next, request)

    response = await call_next(request)
    # The counts were read before this request was recorded, hence the "- 1".
    response.headers["X-RateLimit-Limit-Hour"] = str(quota.calls_per_hour)
    response.headers["X-RateLimit-Limit-Day"] = str(quota.calls_per_day)
    response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, quota.calls_per_hour - hourly_count - 1))
    response.headers["X-RateLimit-Remaining-Day"] = str(max(0, quota.calls_per_day - daily_count - 1))

    return response
176
+
177
async def _get_client_quota(self, api_key: str = None) -> APIKeyQuota:
    """Pick the quota entry from ``self.default_quotas`` for an API key.

    A missing or empty key maps to the ``'free'`` entry. Otherwise the key's
    prefix selects the tier, and any key without a recognised prefix maps to
    ``'basic'``.
    """
    if not api_key:
        return self.default_quotas['free']

    # Placeholder tier detection by key prefix; a production deployment is
    # expected to look the tier up in a database instead.
    prefix_to_tier = (
        ('ent_', 'enterprise'),
        ('prem_', 'premium'),
        ('basic_', 'basic'),
    )
    tier_name = next(
        (tier for prefix, tier in prefix_to_tier if api_key.startswith(prefix)),
        'basic',  # Default for unknown keys
    )
    return self.default_quotas[tier_name]
192
+
193
def _rate_limit_response(self, limit: int, period: str, current_count: int):
    """Build the HTTP 429 JSON reply for an exhausted quota.

    Args:
        limit: The quota that was hit.
        period: Window name; it is title-cased into the header names, and
            ``"hour"`` selects a 3600 s ``Retry-After`` (any other value
            gets 86400 s).
        current_count: Usage figure reported back as ``current_usage``.

    Returns:
        A ``JSONResponse`` carrying the error payload and rate-limit headers.
    """
    from starlette.responses import JSONResponse

    header_suffix = period.title()
    if period == "hour":
        retry_after = "3600"
    else:
        retry_after = "86400"

    error_details = {
        "code": "RATE_LIMIT_EXCEEDED",
        "message": f"Rate limit exceeded. Maximum {limit} requests per {period}.",
        "type": "RateLimitError",
        "limit": limit,
        "period": period,
        "current_usage": current_count,
    }
    response_headers = {
        f"X-RateLimit-Limit-{header_suffix}": str(limit),
        f"X-RateLimit-Remaining-{header_suffix}": "0",
        "Retry-After": retry_after,
    }
    return JSONResponse(
        status_code=429,
        content={"error": error_details},
        headers=response_headers,
    )
214
+
215
async def _fallback_rate_limiting(self, client_id: str, quota: APIKeyQuota,
                                  current_time: float, call_next: Callable, request: Request):
    """Rate-limit in process memory when Redis is not used.

    Only the hourly quota is enforced here. Each client gets a window that
    starts at its first request and lasts ``self.period`` seconds.

    Args:
        client_id: Identifier the counters are keyed by.
        quota: Limits for this client; only ``calls_per_hour`` is read.
        current_time: Timestamp (seconds) used for window bookkeeping.
        call_next: Callable that forwards the request down the stack.
        request: Incoming HTTP request.

    Returns:
        The downstream response with hourly ``X-RateLimit-*`` headers, or the
        response built by ``_rate_limit_response`` when the quota is used up.
    """
    # Drop every window that has expired. Whatever survives is still inside
    # its window, so no separate "reset window" branch is needed below.
    self.clients = {
        cid: data for cid, data in self.clients.items()
        if current_time - data["window_start"] < self.period
    }

    client_data = self.clients.get(client_id)
    if client_data is None:
        # New client, or its previous window was just pruned.
        client_data = {"requests": 1, "window_start": current_time}
        self.clients[client_id] = client_data
    elif client_data["requests"] >= quota.calls_per_hour:
        return self._rate_limit_response(quota.calls_per_hour, "hour", client_data["requests"])
    else:
        client_data["requests"] += 1

    # Snapshot before awaiting: other requests may bump the counter meanwhile.
    used = client_data["requests"]

    response = await call_next(request)
    # Keep the hourly headers consistent with the Redis-backed path.
    response.headers["X-RateLimit-Limit-Hour"] = str(quota.calls_per_hour)
    response.headers["X-RateLimit-Remaining-Hour"] = str(max(0, quota.calls_per_hour - used))
    return response
245
+
246
+
247
class InputSanitizationMiddleware(BaseHTTPMiddleware):
    """Reject oversized or unsupported-media-type requests before the app sees them.

    Only request headers are inspected; the body itself is not read here, so a
    request that sends no ``Content-Length`` header is not size-checked.
    """

    def __init__(self, app: ASGIApp, max_body_size: int = 100 * 1024 * 1024):  # 100MB default
        """Store the wrapped app and the maximum accepted declared body size in bytes."""
        super().__init__(app)
        self.max_body_size = max_body_size

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate ``Content-Length`` and ``Content-Type``, then forward the request.

        Returns:
            400 when ``Content-Length`` is not an integer, 413 when it exceeds
            ``max_body_size``, 415 when a POST/PUT/PATCH request has an
            unsupported ``Content-Type``, otherwise the downstream response.
            Exceptions raised downstream are propagated unchanged rather than
            being reported to the client as a malformed request.
        """
        content_length = request.headers.get('content-length')
        if content_length:
            # Only the header parse can fail because of malformed client
            # input, so it is the only statement guarded here.
            try:
                declared_size = int(content_length)
            except ValueError as e:
                logger.error("Input sanitization failed", error=str(e))
                return JSONResponse(
                    status_code=400,
                    content={
                        "error": {
                            "code": "BAD_REQUEST",
                            "message": "Invalid request format",
                            "type": "RequestError"
                        }
                    }
                )

            if declared_size > self.max_body_size:
                return JSONResponse(
                    status_code=413,
                    content={
                        "error": {
                            "code": "PAYLOAD_TOO_LARGE",
                            "message": f"Request body too large. Maximum size: {self.max_body_size} bytes",
                            "type": "RequestError"
                        }
                    }
                )

        # Validate Content-Type for requests that normally carry a body.
        if request.method in ('POST', 'PUT', 'PATCH'):
            # Media types are case-insensitive, so normalise before matching.
            content_type = request.headers.get('content-type', '').strip().lower()
            allowed_prefixes = (
                'application/json',
                'multipart/form-data',
                'application/x-www-form-urlencoded',
            )
            if not content_type.startswith(allowed_prefixes):
                return JSONResponse(
                    status_code=415,
                    content={
                        "error": {
                            "code": "UNSUPPORTED_MEDIA_TYPE",
                            "message": "Unsupported media type",
                            "type": "RequestError"
                        }
                    }
                )

        return await call_next(request)
300
+
301
+
302
class SecurityAuditMiddleware(BaseHTTPMiddleware):
    """Log security-relevant observations about each request.

    This middleware only logs: suspicious URL paths, requests that take longer
    than 30 seconds, and responses with status 401. It never blocks or alters
    a request or response.
    """

    def __init__(self, app: ASGIApp, log_suspicious_activity: bool = True):
        """Store the wrapped app, the logging switch and the pattern list."""
        super().__init__(app)
        self.log_suspicious_activity = log_suspicious_activity
        # Regular expressions matched case-insensitively against the URL path.
        self.suspicious_patterns = [
            r'\.\./',  # Directory traversal
            r'<script',  # XSS attempts
            r'union\s+select',  # SQL injection
            r'javascript:',  # XSS
            r'eval\s*\(',  # Code injection
            r'/etc/passwd',  # File access attempts
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Audit the request, forward it, then log slow or unauthenticated outcomes."""
        # Monotonic clock: the elapsed time must not be skewed by wall-clock
        # adjustments while the request is in flight.
        started = time.monotonic()

        if self.log_suspicious_activity:
            self._check_for_suspicious_activity(request)

        response = await call_next(request)

        processing_time = time.monotonic() - started

        if processing_time > 30:  # Slow request detection
            logger.warning(
                "Slow request detected",
                path=request.url.path,
                processing_time=processing_time,
                client_ip=self._get_client_ip(request)
            )

        if response.status_code == 401:
            logger.warning(
                "Authentication failed",
                path=request.url.path,
                client_ip=self._get_client_ip(request)
            )

        return response

    def _check_for_suspicious_activity(self, request: Request):
        """Log a warning for every configured pattern found in the URL path.

        Only ``request.url.path`` is examined; the query string, headers and
        body are not inspected.
        """
        import re

        url_path = request.url.path
        for pattern in self.suspicious_patterns:
            if re.search(pattern, url_path, re.IGNORECASE):
                logger.warning(
                    "Suspicious pattern in URL",
                    pattern=pattern,
                    url=url_path,
                    client_ip=self._get_client_ip(request)
                )

    def _get_client_ip(self, request: Request) -> str:
        """Return the first ``X-Forwarded-For`` entry, else the peer host, else ``'unknown'``."""
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return request.client.host if request.client else 'unknown'
0 commit comments