Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 99 additions & 86 deletions cli/app/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,115 @@
import os
import re
import sys
from typing import Any, Dict, Optional

import yaml

from app.utils.message import MISSING_CONFIG_KEY_MESSAGE


class Config:
def __init__(self, default_env="PRODUCTION"):
self.default_env = default_env
self._yaml_config = None
self._user_config_file = None
self._cache = {}
def get_config_file_path(default_env: str = "PRODUCTION") -> str:
"""Get the path to the config file based on environment"""
config_file = "config.dev.yaml" if default_env.upper() == "DEVELOPMENT" else "config.prod.yaml"

if getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS"):
return os.path.join(sys._MEIPASS, "helpers", config_file)
else:
return os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../helpers", config_file))


def load_config_file(config_file_path: str) -> Dict[str, Any]:
"""Load YAML config file"""
with open(config_file_path, "r") as f:
return yaml.safe_load(f) or {}


def get_active_config(user_config_file: Optional[str] = None, default_env: str = "PRODUCTION") -> Dict[str, Any]:
"""Get the active config (user config if provided, else default)"""
if user_config_file:
if not os.path.exists(user_config_file):
raise FileNotFoundError(f"Config file not found: {user_config_file}")
return load_config_file(user_config_file)

config_file_path = get_config_file_path(default_env)
return load_config_file(config_file_path)


def get_env(default_env: str = "PRODUCTION") -> str:
"""Get environment from ENV variable or default"""
return os.environ.get("ENV", default_env)


def is_development(default_env: str = "PRODUCTION") -> bool:
"""Check if current environment is development"""
return get_env(default_env).upper() == "DEVELOPMENT"


def get_config_value(
config: Dict[str, Any],
path: str,
cache: Optional[Dict[str, Any]] = None
) -> Any:
"""Get config value using dot notation path"""
if cache is None:
cache = {}

if path in cache:
return cache[path]

keys = path.split(".")
value = config
for key in keys:
if isinstance(value, dict) and key in value:
value = value[key]
else:
raise KeyError(MISSING_CONFIG_KEY_MESSAGE.format(path=path, key=key))

if isinstance(value, str):
value = expand_env_placeholders(value)

cache[path] = value
return value

config_file = "config.dev.yaml" if default_env.upper() == "DEVELOPMENT" else "config.prod.yaml"

if getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS"):
self._yaml_path = os.path.join(sys._MEIPASS, "helpers", config_file)
else:
self._yaml_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../helpers", config_file))

def get_env(self):
return os.environ.get("ENV", self.default_env)

def is_development(self):
return self.get_env().upper() == "DEVELOPMENT"

def load_user_config(self, config_file: str):
"""Set user config file to replace default config."""
if config_file and not os.path.exists(config_file):
raise FileNotFoundError(f"Config file not found: {config_file}")
self._user_config_file = config_file
self._yaml_config = None
self._cache = {}

def _get_active_config(self):
"""Get the active config (user config if provided, else default)."""
if self._user_config_file:
if self._yaml_config is None:
with open(self._user_config_file, "r") as f:
self._yaml_config = yaml.safe_load(f)
return self._yaml_config

if self._yaml_config is None:
with open(self._yaml_path, "r") as f:
self._yaml_config = yaml.safe_load(f)
return self._yaml_config

def get(self, path: str):
"""Get config value using dot notation path."""
if path in self._cache:
return self._cache[path]

config = self._get_active_config()
keys = path.split(".")
for key in keys:
if isinstance(config, dict) and key in config:
config = config[key]
else:
raise KeyError(MISSING_CONFIG_KEY_MESSAGE.format(path=path, key=key))

if isinstance(config, str):
config = expand_env_placeholders(config)

self._cache[path] = config
return config

def get_service_env_values(self, service_env_path: str):
"""Get service environment values as a dictionary."""
env_config = self.get(service_env_path)
if not isinstance(env_config, dict):
raise ValueError(f"Expected dictionary at path '{service_env_path}'")
return {key: expand_env_placeholders(value) if isinstance(value, str) else value for key, value in env_config.items()}

def load_yaml_config(self):
"""Return the active config dict (for backward compatibility)."""
return self._get_active_config()

def get_yaml_value(self, path: str):
"""Alias for get() for backward compatibility."""
return self.get(path)

def unflatten_config(self, flattened_config: dict) -> dict:
"""Convert flattened config back to nested structure."""
nested = {}
for key, value in flattened_config.items():
keys = key.split(".")
current = nested
for k in keys[:-1]:
if k not in current:
current[k] = {}
current = current[k]
current[keys[-1]] = value
return nested
def get_service_env_values(config: Dict[str, Any], service_env_path: str) -> Dict[str, Any]:
"""Get service environment values as a dictionary"""
env_config = get_config_value(config, service_env_path)
if not isinstance(env_config, dict):
raise ValueError(f"Expected dictionary at path '{service_env_path}'")
return {key: expand_env_placeholders(value) if isinstance(value, str) else value for key, value in env_config.items()}


def load_yaml_config(user_config_file: Optional[str] = None, default_env: str = "PRODUCTION") -> Dict[str, Any]:
"""Return the active config dict (for backward compatibility)"""
return get_active_config(user_config_file, default_env)


def get_yaml_value(
config: Dict[str, Any],
path: str,
cache: Optional[Dict[str, Any]] = None
) -> Any:
"""Alias for get_config_value() for backward compatibility"""
return get_config_value(config, path, cache)


def unflatten_config(flattened_config: dict) -> dict:
"""Convert flattened config back to nested structure"""
nested = {}
for key, value in flattened_config.items():
keys = key.split(".")
current = nested
for k in keys[:-1]:
if k not in current:
current[k] = {}
current = current[k]
current[keys[-1]] = value
return nested


def expand_env_placeholders(value: str) -> str:
# Expand environment placeholders in the form ${ENV_VAR:-default}
"""Expand environment placeholders in the form ${ENV_VAR:-default}"""
# Supports nested expansions like ${VAR1:-${VAR2:-default}}
pattern = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(:-([^}]*))?}")
max_iterations = 10 # Prevent infinite loops
Expand All @@ -116,6 +128,7 @@ def replacer(match):
return value


# Config path constants
VIEW_ENV_FILE = "services.view.env.VIEW_ENV_FILE"
API_ENV_FILE = "services.api.env.API_ENV_FILE"
DEFAULT_REPO = "clone.repo"
Expand Down
Loading