|
| 1 | +from flask import Flask, request, Response |
| 2 | +import requests |
| 3 | +import time |
| 4 | +import wakeonlan |
| 5 | +import os |
| 6 | +import yaml |
| 7 | +import logging |
| 8 | +import sys |
| 9 | +from urllib.parse import urljoin |
| 10 | +from waitress import serve |
| 11 | + |
| 12 | +# Set up logging |
| 13 | +LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper() |
| 14 | +logging.basicConfig(level=LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(message)s') |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | +app = Flask(__name__) |
| 18 | + |
| 19 | +# Configuration |
| 20 | +SERVICES_CONFIG_PATH = os.getenv('SERVICES_CONFIG_PATH', '/config/services.yaml') |
| 21 | +GLOBAL_POLL_INTERVAL = int(os.getenv('GLOBAL_POLL_INTERVAL', 5)) # seconds |
| 22 | +GLOBAL_MAX_RETRIES = int(os.getenv('GLOBAL_MAX_RETRIES', 10)) |
| 23 | +SERVER_PORT = int(os.getenv('SERVER_PORT', 3000)) |
| 24 | +GLOBAL_REQUEST_TIMEOUT = int(os.getenv('GLOBAL_REQUEST_TIMEOUT', 5)) # seconds |
| 25 | +GLOBAL_AWAKE_REQUEST_TIMEOUT = int(os.getenv('GLOBAL_AWAKE_REQUEST_TIMEOUT', GLOBAL_REQUEST_TIMEOUT)) # seconds |
| 26 | + |
| 27 | +def get_service_configs(): |
| 28 | + |
| 29 | + service_configs = {} |
| 30 | + logger.debug(f"Loading service configurations from {SERVICES_CONFIG_PATH} and environment variables") |
| 31 | + |
| 32 | + # Load from YAML file if it exists |
| 33 | + yaml_file_path = SERVICES_CONFIG_PATH |
| 34 | + if os.path.exists(yaml_file_path): |
| 35 | + logger.debug(f"Found YAML config file at {yaml_file_path}") |
| 36 | + with open(yaml_file_path, 'r') as file: |
| 37 | + yaml_config = yaml.safe_load(file) |
| 38 | + logger.debug(f"Loaded YAML config: {yaml_config}") |
| 39 | + for host, config in yaml_config.items(): |
| 40 | + host = host.lower() |
| 41 | + service_configs[host] = { |
| 42 | + "base_url": config.get("base_url"), |
| 43 | + "awake_check_endpoint": config.get("awake_check_endpoint"), |
| 44 | + "mac_address": config.get("mac_address"), |
| 45 | + "poll_interval": int(config.get("poll_interval", GLOBAL_POLL_INTERVAL)), |
| 46 | + "max_retries": int(config.get("max_retries", GLOBAL_MAX_RETRIES)), |
| 47 | + "request_timeout": int(config.get("request_timeout", GLOBAL_REQUEST_TIMEOUT)), |
| 48 | + "awake_request_timeout": int(config.get("awake_request_timeout", GLOBAL_AWAKE_REQUEST_TIMEOUT)), |
| 49 | + } |
| 50 | + logger.debug(f"Added service from YAML: {host} with config: {service_configs[host]}") |
| 51 | + else: |
| 52 | + logger.warning(f"YAML config file not found at {yaml_file_path}") |
| 53 | + |
| 54 | + # Process environment variables to override or add service configurations |
| 55 | + env_override_count = 0 |
| 56 | + |
| 57 | + PREFIX = 'SERVICE_' |
| 58 | + VALID_ENV_SUFFIXES = { |
| 59 | + "BASE_URL", |
| 60 | + "AWAKE_CHECK_ENDPOINT", |
| 61 | + "MAC_ADDRESS", |
| 62 | + "POLL_INTERVAL", |
| 63 | + "MAX_RETRIES", |
| 64 | + "REQUEST_TIMEOUT", |
| 65 | + "AWAKE_REQUEST_TIMEOUT", |
| 66 | + } |
| 67 | + SUFFIX_TO_CONFIG_KEY = {s: s.lower() for s in VALID_ENV_SUFFIXES} |
| 68 | + NUMERIC_CONFIG_KEYS = { |
| 69 | + "poll_interval", |
| 70 | + "max_retries", |
| 71 | + "request_timeout", |
| 72 | + "awake_request_timeout", |
| 73 | + } |
| 74 | + DEFAULT_CONFIG_TEMPLATE = { |
| 75 | + "base_url": None, |
| 76 | + "awake_check_endpoint": None, |
| 77 | + "mac_address": None, |
| 78 | + "poll_interval": GLOBAL_POLL_INTERVAL, |
| 79 | + "max_retries": GLOBAL_MAX_RETRIES, |
| 80 | + "request_timeout": GLOBAL_REQUEST_TIMEOUT, |
| 81 | + "awake_request_timeout": GLOBAL_AWAKE_REQUEST_TIMEOUT, |
| 82 | + } |
| 83 | + |
| 84 | + env_override_count = 0 |
| 85 | + |
| 86 | + |
| 87 | + logger.info("Scanning environment variables for service configurations...") |
| 88 | + |
| 89 | + VALID_SUFFIX_MAP = {'_' + s: s for s in VALID_ENV_SUFFIXES} |
| 90 | + |
| 91 | + for key, value in os.environ.items(): |
| 92 | + if not key.startswith(PREFIX): |
| 93 | + continue # Skip variables not starting with the prefix |
| 94 | + |
| 95 | + matched_env_suffix = None |
| 96 | + extracted_host = None |
| 97 | + |
| 98 | + for suffix_with_underscore, original_suffix in VALID_SUFFIX_MAP.items(): |
| 99 | + if key.endswith(suffix_with_underscore): |
| 100 | + # Check if this is the longest match. If suffixes overlap, this is needed to make sure this is the longest match. |
| 101 | + # For current valid suffixes, it doesn't matter, but we'll do this in case overlapping keys are added. |
| 102 | + if matched_env_suffix is None or len(original_suffix) > len(matched_env_suffix): |
| 103 | + matched_env_suffix = original_suffix |
| 104 | + |
| 105 | + if matched_env_suffix: |
| 106 | + suffix_with_underscore = '_' + matched_env_suffix |
| 107 | + end_of_host_pos = len(key) - len(suffix_with_underscore) |
| 108 | + host_part = key[len(PREFIX):end_of_host_pos] |
| 109 | + |
| 110 | + if not host_part: |
| 111 | + logger.warning(f"Ignoring env var '{key}': Contains prefix and valid suffix but no host.") |
| 112 | + continue |
| 113 | + |
| 114 | + host = host_part.lower() |
| 115 | + config_key = SUFFIX_TO_CONFIG_KEY[matched_env_suffix] |
| 116 | + |
| 117 | + host_config = service_configs.setdefault(host, DEFAULT_CONFIG_TEMPLATE.copy()) |
| 118 | + |
| 119 | + try: |
| 120 | + if config_key in NUMERIC_CONFIG_KEYS: |
| 121 | + host_config[config_key] = int(value) |
| 122 | + else: |
| 123 | + host_config[config_key] = value |
| 124 | + |
| 125 | + logger.debug(f"Applied config: host='{host}', key='{config_key}', value='{host_config[config_key]}' (from env var '{key}')") |
| 126 | + env_override_count += 1 |
| 127 | + |
| 128 | + except (ValueError, TypeError) as e: |
| 129 | + logger.error(f"Failed to apply config from env var '{key}={value}' for host '{host}', key '{config_key}': Invalid value format - {e}") |
| 130 | + |
| 131 | + else: |
| 132 | + logger.debug(f"Ignoring env var '{key}': Does not end with a known service suffix.") |
| 133 | + |
| 134 | + |
| 135 | + logger.info(f"Finished scanning environment variables. Applied {env_override_count} overrides.") |
| 136 | + |
| 137 | + # Validate all services have required configuration |
| 138 | + valid_services = {} |
| 139 | + for host, config in service_configs.items(): |
| 140 | + if not (config["base_url"] and config["awake_check_endpoint"] and config["mac_address"]): |
| 141 | + logger.error(f"Service {host} does not have all of the required config values: base URL, awake check endpoint, and MAC address") |
| 142 | + continue |
| 143 | + valid_services[host] = config |
| 144 | + |
| 145 | + if env_override_count > 0: |
| 146 | + logger.debug(f"Applied {env_override_count} configuration overrides from environment variables") |
| 147 | + else: |
| 148 | + logger.debug("No service configurations found in environment variables") |
| 149 | + |
| 150 | + logger.info(f"Loaded {len(valid_services)} total valid services: {list(valid_services.keys())}") |
| 151 | + return valid_services |
| 152 | + |
| 153 | +service_configs = get_service_configs() |
| 154 | + |
| 155 | +def send_wol_packet(mac_address): |
| 156 | + logger.info(f"Sending WoL packet to {mac_address}") |
| 157 | + wakeonlan.send_magic_packet(mac_address) |
| 158 | + |
| 159 | +def is_server_awake(url, timeout): |
| 160 | + try: |
| 161 | + # Simple GET request for awake check. Connection pooling handles cleanup. |
| 162 | + response = requests.request( |
| 163 | + method='GET', |
| 164 | + url=url, |
| 165 | + timeout=timeout |
| 166 | + ) |
| 167 | + logger.debug(f"Awake check to {url} status: {response.status_code}") |
| 168 | + if 200 <= response.status_code < 300: |
| 169 | + logger.info(f"Server at {url} is awake (status: {response.status_code})") |
| 170 | + return True |
| 171 | + else: |
| 172 | + logger.info(f"Server at {url} responded status {response.status_code}. Considering not awake.") |
| 173 | + return False |
| 174 | + except requests.RequestException as e: |
| 175 | + logger.info(f"Awake check to {url} failed: {e}") |
| 176 | + return False |
| 177 | + |
| 178 | +@app.route('/', defaults={'path': ''}, methods=['GET', 'POST']) |
| 179 | +@app.route('/<path:path>', methods=['GET', 'POST']) |
| 180 | +def proxy_request(path): |
| 181 | + original_request = request |
| 182 | + data = original_request.data |
| 183 | + headers = {key: value for (key, value) in original_request.headers if key != 'Host'} |
| 184 | + |
| 185 | + logger.debug(f"Received request for path: {path}, method: {original_request.method}") |
| 186 | + logger.debug(f"Request headers: {original_request.headers}") |
| 187 | + |
| 188 | + host_header = original_request.headers.get('Host') |
| 189 | + if not host_header: |
| 190 | + logger.error("Host header is missing.") |
| 191 | + return "Host header is missing.", 400 |
| 192 | + |
| 193 | + logger.debug(f"Processing request with Host header: {host_header}") |
| 194 | + logger.debug(f"Available services: {list(service_configs.keys())}") |
| 195 | + |
| 196 | + # Check if the host matches any of the configured services |
| 197 | + target_service = None |
| 198 | + for identifier in service_configs: |
| 199 | + logger.debug(f"Comparing host header '{host_header}' with service identifier '{identifier}'") |
| 200 | + if host_header == identifier: |
| 201 | + target_service = identifier |
| 202 | + logger.debug(f"Found matching service: {target_service}") |
| 203 | + break |
| 204 | + |
| 205 | + if not target_service: |
| 206 | + logger.error(f"Unknown target service: {host_header}. Available services: {list(service_configs.keys())}") |
| 207 | + return f"Unknown target service: {host_header}.", 404 |
| 208 | + |
| 209 | + config = service_configs.get(target_service) |
| 210 | + if not config: |
| 211 | + logger.error(f"Unknown target service: {target_service}. This should not happen as we already checked the service exists.") |
| 212 | + return f"Unknown target service: {target_service}.", 404 |
| 213 | + |
| 214 | + logger.debug(f"Using configuration for service {target_service}: {config}") |
| 215 | + base_url = config["base_url"] |
| 216 | + destination_url = urljoin(base_url, request.full_path) |
| 217 | + awake_check_endpoint = config["awake_check_endpoint"] |
| 218 | + awake_check_url = urljoin(base_url, awake_check_endpoint) |
| 219 | + mac_address = config["mac_address"] |
| 220 | + poll_interval = config["poll_interval"] |
| 221 | + max_retries = config["max_retries"] |
| 222 | + request_timeout = config["request_timeout"] |
| 223 | + awake_request_timeout = config["awake_request_timeout"] |
| 224 | + |
| 225 | + # Poll until the server is awake |
| 226 | + retries = -1 # First try is not a retry, so start from -1 |
| 227 | + server_awake = False |
| 228 | + |
| 229 | + while retries < max_retries: |
| 230 | + if retries > -1: |
| 231 | + logger.info(f"Server {target_service} is not awake. Sending wake-on-LAN magic packet and retrying in {poll_interval} seconds...") |
| 232 | + send_wol_packet(mac_address) |
| 233 | + time.sleep(poll_interval) |
| 234 | + |
| 235 | + server_awake = is_server_awake(awake_check_url, awake_request_timeout) |
| 236 | + |
| 237 | + if server_awake: |
| 238 | + break |
| 239 | + |
| 240 | + retries += 1 |
| 241 | + |
| 242 | + if not server_awake: |
| 243 | + logger.error(f"Failed to reach the server {target_service} after {max_retries} attempts.") |
| 244 | + return f"Failed to reach the server {target_service} after {max_retries} attempts.", 503 |
| 245 | + |
| 246 | + # Make the actual request |
| 247 | + try: |
| 248 | + response = requests.request( |
| 249 | + method=original_request.method, |
| 250 | + url=destination_url, |
| 251 | + data=data, |
| 252 | + headers=headers, |
| 253 | + timeout=request_timeout, # Timeout for connection/initial read; should not time out in middle of stream |
| 254 | + stream=True # Always use stream=True to handle both streaming and non-streaming responses robustly |
| 255 | + ) |
| 256 | + logger.info(f"Proxying response from {destination_url} with status code: {response.status_code}") |
| 257 | + |
| 258 | + # Filter out hop-by-hop headers that shouldn't be forwarded directly. |
| 259 | + # Let Flask/Waitress handle Content-Length or Transfer-Encoding as needed. |
| 260 | + excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection'] |
| 261 | + response_headers = [(name, value) for (name, value) in response.headers.items() |
| 262 | + if name.lower() not in excluded_headers] |
| 263 | + |
| 264 | + # Create a streaming Flask response using iter_content. |
| 265 | + # This works correctly whether the original response from the downstream server |
| 266 | + # was chunked (streaming) or had a fixed Content-Length (non-streaming). |
| 267 | + return Response(response.iter_content(chunk_size=8192), status=response.status_code, headers=response_headers) |
| 268 | + except requests.RequestException as e: |
| 269 | + logger.error(f"Request failed after server {target_service} woke up: {e}") |
| 270 | + return f"Failed to reach the server {target_service} after it woke up.", 503 |
| 271 | + |
| 272 | +if __name__ == '__main__': |
| 273 | + logger.info(f"Starting wake-on-http server on port {SERVER_PORT}") |
| 274 | + logger.info(f"Log level set to: {LOG_LEVEL}") |
| 275 | + logger.info(f"Services config path: {SERVICES_CONFIG_PATH}") |
| 276 | + logger.info(f"Global poll interval: {GLOBAL_POLL_INTERVAL} seconds") |
| 277 | + logger.info(f"Global max retries: {GLOBAL_MAX_RETRIES}") |
| 278 | + logger.info(f"Global request timeout: {GLOBAL_REQUEST_TIMEOUT} seconds") |
| 279 | + logger.info(f"Global awake request timeout: {GLOBAL_AWAKE_REQUEST_TIMEOUT} seconds") |
| 280 | + |
| 281 | + if not service_configs: |
| 282 | + logger.error("No services configured! Please provide configuration via services.yaml or environment variables. Exiting.") |
| 283 | + sys.exit(1) # Exit if no services are configured |
| 284 | + else: |
| 285 | + logger.info(f"Configured services:") |
| 286 | + for service_name, config in service_configs.items(): |
| 287 | + logger.info(f" - {service_name}:") |
| 288 | + logger.info(f" Base URL: {config['base_url']}") |
| 289 | + logger.info(f" Awake Check Endpoint: {config['awake_check_endpoint']}") |
| 290 | + logger.info(f" MAC Address: {config['mac_address']}") |
| 291 | + logger.info(f" Poll Interval: {config['poll_interval']} seconds") |
| 292 | + logger.info(f" Max Retries: {config['max_retries']}") |
| 293 | + logger.info(f" Request Timeout: {config['request_timeout']}") |
| 294 | + logger.info(f" Awake Request Timeout: {config['awake_request_timeout']}") |
| 295 | + |
| 296 | + serve(app, host='0.0.0.0', port=SERVER_PORT) |
0 commit comments