|
| 1 | +""" |
| 2 | +API Authentication and Rate Limiting Module |
| 3 | +Provides API key validation, rate limiting, and tier-based access control |
| 4 | +""" |
| 5 | +import functools |
| 6 | +import hashlib |
| 7 | +import secrets |
| 8 | +import time |
| 9 | +import json |
| 10 | +import os |
| 11 | +from datetime import datetime, timedelta |
| 12 | +from flask import request, jsonify, g |
| 13 | + |
| 14 | +# Rate limit storage (in production, use Redis) |
| 15 | +_rate_limit_storage = {} |
| 16 | +_api_keys_file = 'api_keys.json' |
| 17 | + |
| 18 | + |
| 19 | +class RateLimiter: |
| 20 | + """Rate limiter with tier-based limits""" |
| 21 | + |
| 22 | + # Rate limits by tier: (requests_per_minute, requests_per_day) |
| 23 | + TIER_LIMITS = { |
| 24 | + 'free': (10, 100), |
| 25 | + 'basic': (60, 1000), |
| 26 | + 'premium': (300, 10000), |
| 27 | + 'enterprise': (10000, 1000000) # Effectively unlimited |
| 28 | + } |
| 29 | + |
| 30 | + def __init__(self): |
| 31 | + self.storage = _rate_limit_storage |
| 32 | + |
| 33 | + def _get_window_key(self, api_key: str, window: str) -> str: |
| 34 | + """Generate storage key for rate limit window""" |
| 35 | + return f"{api_key}:{window}" |
| 36 | + |
| 37 | + def check_rate_limit(self, api_key: str, tier: str) -> tuple: |
| 38 | + """ |
| 39 | + Check if request is within rate limits |
| 40 | + Returns: (allowed: bool, remaining: int, reset_time: int) |
| 41 | + """ |
| 42 | + limits = self.TIER_LIMITS.get(tier, self.TIER_LIMITS['free']) |
| 43 | + per_minute, per_day = limits |
| 44 | + |
| 45 | + now = time.time() |
| 46 | + minute_window = int(now / 60) |
| 47 | + day_window = int(now / 86400) |
| 48 | + |
| 49 | + minute_key = self._get_window_key(api_key, f"min:{minute_window}") |
| 50 | + day_key = self._get_window_key(api_key, f"day:{day_window}") |
| 51 | + |
| 52 | + # Check minute limit |
| 53 | + minute_count = self.storage.get(minute_key, 0) |
| 54 | + if minute_count >= per_minute: |
| 55 | + reset_time = (minute_window + 1) * 60 |
| 56 | + return False, 0, int(reset_time - now) |
| 57 | + |
| 58 | + # Check daily limit |
| 59 | + day_count = self.storage.get(day_key, 0) |
| 60 | + if day_count >= per_day: |
| 61 | + reset_time = (day_window + 1) * 86400 |
| 62 | + return False, 0, int(reset_time - now) |
| 63 | + |
| 64 | + # Increment counters |
| 65 | + self.storage[minute_key] = minute_count + 1 |
| 66 | + self.storage[day_key] = day_count + 1 |
| 67 | + |
| 68 | + # Clean old entries periodically |
| 69 | + self._cleanup_old_entries(now) |
| 70 | + |
| 71 | + remaining = min(per_minute - minute_count - 1, per_day - day_count - 1) |
| 72 | + return True, remaining, 60 - int(now % 60) |
| 73 | + |
| 74 | + def _cleanup_old_entries(self, now: float): |
| 75 | + """Remove expired rate limit entries""" |
| 76 | + if len(self.storage) > 10000: # Only cleanup when storage is large |
| 77 | + current_minute = int(now / 60) |
| 78 | + current_day = int(now / 86400) |
| 79 | + keys_to_delete = [] |
| 80 | + |
| 81 | + for key in self.storage: |
| 82 | + if ':min:' in key: |
| 83 | + window = int(key.split(':')[-1]) |
| 84 | + if window < current_minute - 1: |
| 85 | + keys_to_delete.append(key) |
| 86 | + elif ':day:' in key: |
| 87 | + window = int(key.split(':')[-1]) |
| 88 | + if window < current_day - 1: |
| 89 | + keys_to_delete.append(key) |
| 90 | + |
| 91 | + for key in keys_to_delete: |
| 92 | + del self.storage[key] |
| 93 | + |
| 94 | + |
| 95 | +class APIKeyManager: |
| 96 | + """Manages API keys for authentication""" |
| 97 | + |
| 98 | + def __init__(self): |
| 99 | + self.keys_file = _api_keys_file |
| 100 | + self.keys = self._load_keys() |
| 101 | + |
| 102 | + def _load_keys(self) -> dict: |
| 103 | + """Load API keys from file""" |
| 104 | + if os.path.exists(self.keys_file): |
| 105 | + try: |
| 106 | + with open(self.keys_file, 'r') as f: |
| 107 | + return json.load(f) |
| 108 | + except: |
| 109 | + pass |
| 110 | + return {} |
| 111 | + |
| 112 | + def _save_keys(self): |
| 113 | + """Save API keys to file""" |
| 114 | + try: |
| 115 | + with open(self.keys_file, 'w') as f: |
| 116 | + json.dump(self.keys, f, indent=2) |
| 117 | + except Exception as e: |
| 118 | + print(f"Error saving API keys: {e}") |
| 119 | + |
| 120 | + def generate_api_key(self, user_id: str = None) -> str: |
| 121 | + """Generate a new API key""" |
| 122 | + key = f"arb_{secrets.token_hex(24)}" |
| 123 | + key_hash = hashlib.sha256(key.encode()).hexdigest() |
| 124 | + |
| 125 | + self.keys[key_hash] = { |
| 126 | + 'user_id': user_id or 'default', |
| 127 | + 'created_at': datetime.now().isoformat(), |
| 128 | + 'last_used': None, |
| 129 | + 'requests_count': 0, |
| 130 | + 'active': True |
| 131 | + } |
| 132 | + self._save_keys() |
| 133 | + return key |
| 134 | + |
| 135 | + def validate_api_key(self, api_key: str) -> dict: |
| 136 | + """ |
| 137 | + Validate API key and return key info |
| 138 | + Returns None if invalid |
| 139 | + """ |
| 140 | + if not api_key: |
| 141 | + return None |
| 142 | + |
| 143 | + # For demo/testing: accept any key starting with 'demo-' or 'test-' |
| 144 | + if api_key.startswith('demo-') or api_key.startswith('test-'): |
| 145 | + return { |
| 146 | + 'user_id': 'demo_user', |
| 147 | + 'tier': 'premium', |
| 148 | + 'active': True |
| 149 | + } |
| 150 | + |
| 151 | + key_hash = hashlib.sha256(api_key.encode()).hexdigest() |
| 152 | + key_info = self.keys.get(key_hash) |
| 153 | + |
| 154 | + if key_info and key_info.get('active', True): |
| 155 | + # Update usage stats |
| 156 | + key_info['last_used'] = datetime.now().isoformat() |
| 157 | + key_info['requests_count'] = key_info.get('requests_count', 0) + 1 |
| 158 | + self._save_keys() |
| 159 | + return key_info |
| 160 | + |
| 161 | + return None |
| 162 | + |
| 163 | + def revoke_api_key(self, api_key: str) -> bool: |
| 164 | + """Revoke an API key""" |
| 165 | + key_hash = hashlib.sha256(api_key.encode()).hexdigest() |
| 166 | + if key_hash in self.keys: |
| 167 | + self.keys[key_hash]['active'] = False |
| 168 | + self._save_keys() |
| 169 | + return True |
| 170 | + return False |
| 171 | + |
| 172 | + def list_keys(self, user_id: str = None) -> list: |
| 173 | + """List all API keys for a user (returns masked keys)""" |
| 174 | + keys = [] |
| 175 | + for key_hash, info in self.keys.items(): |
| 176 | + if user_id is None or info.get('user_id') == user_id: |
| 177 | + keys.append({ |
| 178 | + 'key_prefix': f"arb_...{key_hash[-8:]}", |
| 179 | + 'created_at': info.get('created_at'), |
| 180 | + 'last_used': info.get('last_used'), |
| 181 | + 'requests_count': info.get('requests_count', 0), |
| 182 | + 'active': info.get('active', True) |
| 183 | + }) |
| 184 | + return keys |
| 185 | + |
| 186 | + |
| 187 | +# Global instances |
| 188 | +rate_limiter = RateLimiter() |
| 189 | +api_key_manager = APIKeyManager() |
| 190 | + |
| 191 | + |
| 192 | +def require_api_key(required_tier: str = 'free'): |
| 193 | + """ |
| 194 | + Decorator to require API key authentication |
| 195 | + Also enforces rate limiting based on tier |
| 196 | + |
| 197 | + Args: |
| 198 | + required_tier: Minimum tier required ('free', 'basic', 'premium', 'enterprise') |
| 199 | + """ |
| 200 | + tier_hierarchy = ['free', 'basic', 'premium', 'enterprise'] |
| 201 | + |
| 202 | + def decorator(f): |
| 203 | + @functools.wraps(f) |
| 204 | + def decorated_function(*args, **kwargs): |
| 205 | + # Get API key from header |
| 206 | + api_key = request.headers.get('X-API-Key') |
| 207 | + |
| 208 | + if not api_key: |
| 209 | + return jsonify({ |
| 210 | + 'error': 'API key required', |
| 211 | + 'message': 'Please provide an API key in the X-API-Key header' |
| 212 | + }), 401 |
| 213 | + |
| 214 | + # Validate API key |
| 215 | + key_info = api_key_manager.validate_api_key(api_key) |
| 216 | + if not key_info: |
| 217 | + return jsonify({ |
| 218 | + 'error': 'Invalid API key', |
| 219 | + 'message': 'The provided API key is invalid or has been revoked' |
| 220 | + }), 401 |
| 221 | + |
| 222 | + # Get subscription tier |
| 223 | + try: |
| 224 | + from src.subscription import SubscriptionManager |
| 225 | + sub_manager = SubscriptionManager() |
| 226 | + tier = sub_manager.get_tier() |
| 227 | + except: |
| 228 | + tier = key_info.get('tier', 'free') |
| 229 | + |
| 230 | + # Check tier requirements |
| 231 | + if tier_hierarchy.index(tier) < tier_hierarchy.index(required_tier): |
| 232 | + return jsonify({ |
| 233 | + 'error': 'Insufficient tier', |
| 234 | + 'message': f'This endpoint requires {required_tier} tier or higher', |
| 235 | + 'current_tier': tier, |
| 236 | + 'required_tier': required_tier |
| 237 | + }), 403 |
| 238 | + |
| 239 | + # Check rate limit |
| 240 | + allowed, remaining, reset_time = rate_limiter.check_rate_limit(api_key, tier) |
| 241 | + |
| 242 | + if not allowed: |
| 243 | + response = jsonify({ |
| 244 | + 'error': 'Rate limit exceeded', |
| 245 | + 'message': 'Too many requests. Please try again later.', |
| 246 | + 'retry_after': reset_time |
| 247 | + }) |
| 248 | + response.headers['X-RateLimit-Remaining'] = '0' |
| 249 | + response.headers['X-RateLimit-Reset'] = str(reset_time) |
| 250 | + response.headers['Retry-After'] = str(reset_time) |
| 251 | + return response, 429 |
| 252 | + |
| 253 | + # Store user info in g for use in endpoint |
| 254 | + g.api_key = api_key |
| 255 | + g.user_id = key_info.get('user_id', 'unknown') |
| 256 | + g.tier = tier |
| 257 | + |
| 258 | + # Call the actual function |
| 259 | + response = f(*args, **kwargs) |
| 260 | + |
| 261 | + # Add rate limit headers to response |
| 262 | + if hasattr(response, 'headers'): |
| 263 | + response.headers['X-RateLimit-Remaining'] = str(remaining) |
| 264 | + response.headers['X-RateLimit-Reset'] = str(reset_time) |
| 265 | + |
| 266 | + return response |
| 267 | + |
| 268 | + return decorated_function |
| 269 | + return decorator |
| 270 | + |
| 271 | + |
| 272 | +def log_api_request(endpoint: str, method: str, status_code: int, response_time_ms: float): |
| 273 | + """Log API request for analytics""" |
| 274 | + try: |
| 275 | + log_entry = { |
| 276 | + 'timestamp': datetime.now().isoformat(), |
| 277 | + 'endpoint': endpoint, |
| 278 | + 'method': method, |
| 279 | + 'status_code': status_code, |
| 280 | + 'response_time_ms': response_time_ms, |
| 281 | + 'user_id': getattr(g, 'user_id', 'anonymous'), |
| 282 | + 'tier': getattr(g, 'tier', 'unknown') |
| 283 | + } |
| 284 | + |
| 285 | + # Append to log file |
| 286 | + with open('api_requests.log', 'a') as f: |
| 287 | + f.write(json.dumps(log_entry) + '\n') |
| 288 | + except: |
| 289 | + pass # Don't fail on logging errors |
0 commit comments