| 
 | 1 | +#  Copyright (c) ZenML GmbH 2025. All Rights Reserved.  | 
 | 2 | +#  | 
 | 3 | +#  Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +#  you may not use this file except in compliance with the License.  | 
 | 5 | +#  You may obtain a copy of the License at:  | 
 | 6 | +#  | 
 | 7 | +#       https://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +#  | 
 | 9 | +#  Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +#  distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express  | 
 | 12 | +#  or implied. See the License for the specific language governing  | 
 | 13 | +#  permissions and limitations under the License.  | 
 | 14 | +"""Module for centralized WarningController implementation."""  | 
 | 15 | + | 
 | 16 | +import logging  | 
 | 17 | +from collections import defaultdict  | 
 | 18 | +from typing import Any  | 
 | 19 | + | 
 | 20 | +from zenml.enums import LoggingLevels  | 
 | 21 | +from zenml.utils.singleton import SingletonMetaClass  | 
 | 22 | +from zenml.utils.warnings.base import WarningConfig, WarningVerbosity  | 
 | 23 | + | 
 | 24 | +logger = logging.getLogger(__name__)  | 
 | 25 | + | 
 | 26 | + | 
 | 27 | +class WarningController(metaclass=SingletonMetaClass):  | 
 | 28 | +    """Class responsible for centralized handling of warning messages."""  | 
 | 29 | + | 
 | 30 | +    def __init__(self) -> None:  | 
 | 31 | +        """WarningController constructor."""  | 
 | 32 | +        self._warning_configs: dict[str, WarningConfig] = {}  | 
 | 33 | +        self._warning_statistics: dict[str, int] = defaultdict(int)  | 
 | 34 | + | 
 | 35 | +    def register(self, warning_configs: dict[str, WarningConfig]) -> None:  | 
 | 36 | +        """Register a warning config collection to the controller.  | 
 | 37 | +
  | 
 | 38 | +        Args:  | 
 | 39 | +            warning_configs: Configs to be registered. Key should be the warning code.  | 
 | 40 | +        """  | 
 | 41 | +        self._warning_configs.update(warning_configs)  | 
 | 42 | + | 
 | 43 | +    @staticmethod  | 
 | 44 | +    def _resolve_call_details() -> tuple[str, int]:  | 
 | 45 | +        import inspect  | 
 | 46 | + | 
 | 47 | +        frame = inspect.stack()[3]  # public methods -> _log -> _resolve  | 
 | 48 | +        module = inspect.getmodule(frame[0])  | 
 | 49 | +        module_name = module.__name__ if module else "<unknown module>"  | 
 | 50 | +        line_number = frame.lineno  | 
 | 51 | + | 
 | 52 | +        return module_name, line_number  | 
 | 53 | + | 
 | 54 | +    @staticmethod  | 
 | 55 | +    def _get_display_message(  | 
 | 56 | +        message: str,  | 
 | 57 | +        module_name: str,  | 
 | 58 | +        line_number: int,  | 
 | 59 | +        config: WarningConfig,  | 
 | 60 | +    ) -> str:  | 
 | 61 | +        """Helper method to build the warning message string.  | 
 | 62 | +
  | 
 | 63 | +        Args:  | 
 | 64 | +            message: The warning message.  | 
 | 65 | +            module_name: The module that the warning call originated from.  | 
 | 66 | +            line_number: The line number that the warning call originated from.  | 
 | 67 | +            config: The warning configuration.  | 
 | 68 | +
  | 
 | 69 | +        Returns: A warning message containing extra fields/info based on warning config.  | 
 | 70 | +        """  | 
 | 71 | +        display = f"[{config.code}]({config.category}) - {message}"  | 
 | 72 | + | 
 | 73 | +        if config.verbosity == WarningVerbosity.MEDIUM:  | 
 | 74 | +            display = f"{module_name}:{line_number} {display}"  | 
 | 75 | + | 
 | 76 | +        if config.verbosity == WarningVerbosity.HIGH:  | 
 | 77 | +            display = f"{display}\n{config.description}"  | 
 | 78 | + | 
 | 79 | +        return display  | 
 | 80 | + | 
 | 81 | +    def _log(  | 
 | 82 | +        self,  | 
 | 83 | +        warning_code: str,  | 
 | 84 | +        message: str,  | 
 | 85 | +        level: LoggingLevels,  | 
 | 86 | +        **kwargs: dict[str, Any],  | 
 | 87 | +    ) -> None:  | 
 | 88 | +        """Core function for warning handling.  | 
 | 89 | +
  | 
 | 90 | +        Args:  | 
 | 91 | +            warning_code: The code of the warning configuration.  | 
 | 92 | +            message: The warning message.  | 
 | 93 | +            level: The level of the warning.  | 
 | 94 | +            **kwargs: Keyword arguments (for formatted messages).  | 
 | 95 | +
  | 
 | 96 | +        """  | 
 | 97 | +        warning_config = self._warning_configs.get(warning_code)  | 
 | 98 | + | 
 | 99 | +        # resolves the module and line number of the warning call.  | 
 | 100 | +        module_name, line_number = self._resolve_call_details()  | 
 | 101 | + | 
 | 102 | +        if not warning_config:  | 
 | 103 | +            # If no config is available just follow default behavior:  | 
 | 104 | +            logger.warning(f"{module_name}:{line_number} - {message}")  | 
 | 105 | +            return  | 
 | 106 | + | 
 | 107 | +        if warning_config.is_throttled:  | 
 | 108 | +            if warning_code in self._warning_statistics:  | 
 | 109 | +                # Throttled code has already appeared - skip.  | 
 | 110 | +                return  | 
 | 111 | + | 
 | 112 | +        display_message = self._get_display_message(  | 
 | 113 | +            message=message,  | 
 | 114 | +            module_name=module_name,  | 
 | 115 | +            line_number=line_number,  | 
 | 116 | +            config=warning_config,  | 
 | 117 | +        )  | 
 | 118 | + | 
 | 119 | +        self._warning_statistics[warning_code] += 1  | 
 | 120 | + | 
 | 121 | +        if level == LoggingLevels.INFO:  | 
 | 122 | +            logger.info(display_message.format(**kwargs))  | 
 | 123 | +        else:  | 
 | 124 | +            # Assumes warning level is the default if an invalid option is passed.  | 
 | 125 | +            logger.warning(display_message.format(**kwargs))  | 
 | 126 | + | 
 | 127 | +    def warn(  | 
 | 128 | +        self, *, warning_code: str, message: str, **kwargs: dict[str, Any]  | 
 | 129 | +    ) -> None:  | 
 | 130 | +        """Method to execute warning handling logic with warning log level.  | 
 | 131 | +
  | 
 | 132 | +        Args:  | 
 | 133 | +            warning_code: The code of the warning (see WarningCodes enum)  | 
 | 134 | +            message: The message to display.  | 
 | 135 | +            **kwargs: Keyword arguments (for formatted messages).  | 
 | 136 | +        """  | 
 | 137 | +        self._log(warning_code, message, LoggingLevels.WARNING, **kwargs)  | 
 | 138 | + | 
 | 139 | +    def info(  | 
 | 140 | +        self, *, warning_code: str, message: str, **kwargs: dict[str, Any]  | 
 | 141 | +    ) -> None:  | 
 | 142 | +        """Method to execute warning handling logic with info log level.  | 
 | 143 | +
  | 
 | 144 | +        Args:  | 
 | 145 | +            warning_code: The code of the warning (see WarningCodes enum)  | 
 | 146 | +            message: The message to display.  | 
 | 147 | +            **kwargs: Keyword arguments (for formatted messages).  | 
 | 148 | +        """  | 
 | 149 | +        self._log(warning_code, message, LoggingLevels.INFO, **kwargs)  | 
 | 150 | + | 
 | 151 | +    @staticmethod  | 
 | 152 | +    def create() -> "WarningController":  | 
 | 153 | +        """Factory function for WarningController.  | 
 | 154 | +
  | 
 | 155 | +        Creates a new warning controller and registers system warning configs.  | 
 | 156 | +
  | 
 | 157 | +        Returns:  | 
 | 158 | +            A warning controller instance.  | 
 | 159 | +        """  | 
 | 160 | +        from zenml.utils.warnings.registry import WARNING_CONFIG_REGISTRY  | 
 | 161 | + | 
 | 162 | +        registry = WarningController()  | 
 | 163 | + | 
 | 164 | +        registry.register(warning_configs=WARNING_CONFIG_REGISTRY)  | 
 | 165 | + | 
 | 166 | +        return registry  | 
0 commit comments