Skip to content

Commit b4e18c7

Browse files
PR review fixes & improvements
1 parent 5ba2e9f commit b4e18c7

File tree

5 files changed

+82
-41
lines changed

5 files changed

+82
-41
lines changed

src/zenml/config/compiler.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -385,31 +385,22 @@ def _validate_docker_settings_usage(
385385
stack: The stack the settings are validated against.
386386
387387
"""
388-
from zenml.orchestrators import (
389-
ContainerizedOrchestrator,
390-
LocalOrchestrator,
391-
)
388+
from zenml.orchestrators import LocalOrchestrator
392389

393-
if not docker_settings or isinstance(
394-
stack.orchestrator, ContainerizedOrchestrator
395-
):
390+
if not docker_settings:
396391
return
397392

398393
warning_message = (
399-
"You are specifying docker settings but you are not using a"
400-
f"containerized orchestrator: {stack.orchestrator.__class__.__name__}."
401-
f"Consider switching stack or using a containerized orchestrator, otherwise"
402-
f"your docker settings will be ignored."
394+
"You are specifying docker settings but the orchestrator"
395+
" you are using (LocalOrchestrator) will not make use of them. "
396+
"Consider switching stacks, removing the settings, or using a "
397+
"different orchestrator."
403398
)
404399

405400
if isinstance(stack.orchestrator, LocalOrchestrator):
406401
WARNING_CONTROLLER.info(
407402
warning_code=WarningCodes.ZML002, message=warning_message
408403
)
409-
else:
410-
WARNING_CONTROLLER.warn(
411-
warning_code=WarningCodes.ZML002, message=warning_message
412-
)
413404

414405
def _filter_and_validate_settings(
415406
self,

src/zenml/utils/warnings/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,26 @@
1313
# permissions and limitations under the License.
1414
"""Warning configuration class and helper enums."""
1515

16-
from enum import Enum
17-
1816
from pydantic import BaseModel, Field
1917

18+
from zenml.utils.enum_utils import StrEnum
19+
2020

21-
class WarningSeverity(str, Enum):
21+
class WarningSeverity(StrEnum):
2222
"""Enum class describing the warning severity."""
2323

2424
LOW = "low"
2525
MEDIUM = "medium"
2626
HIGH = "high"
2727

2828

29-
class WarningCategory(str, Enum):
29+
class WarningCategory(StrEnum):
3030
"""Enum class describing the warning category."""
3131

3232
USAGE = "USAGE"
3333

3434

35-
class WarningVerbosity(str, Enum):
35+
class WarningVerbosity(StrEnum):
3636
"""Enum class describing the warning verbosity."""
3737

3838
LOW = "low"

src/zenml/utils/warnings/controller.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
"""Module for centralized WarningController implementation."""
1515

1616
import logging
17-
from collections import defaultdict
17+
from collections import Counter
1818
from typing import Any
1919

2020
from zenml.enums import LoggingLevels
2121
from zenml.utils.singleton import SingletonMetaClass
2222
from zenml.utils.warnings.base import WarningConfig, WarningVerbosity
23+
from zenml.utils.warnings.registry import WarningCodes
2324

2425
logger = logging.getLogger(__name__)
2526

@@ -30,7 +31,7 @@ class WarningController(metaclass=SingletonMetaClass):
3031
def __init__(self) -> None:
3132
"""WarningController constructor."""
3233
self._warning_configs: dict[str, WarningConfig] = {}
33-
self._warning_statistics: dict[str, int] = defaultdict(int)
34+
self._warning_statistics: dict[str, int] = Counter()
3435

3536
def register(self, warning_configs: dict[str, WarningConfig]) -> None:
3637
"""Register a warning config collection to the controller.
@@ -81,7 +82,7 @@ def _get_display_message(
8182

8283
def _log(
8384
self,
84-
warning_code: str,
85+
warning_code: WarningCodes,
8586
message: str,
8687
level: LoggingLevels,
8788
**kwargs: dict[str, Any],
@@ -125,7 +126,9 @@ def _log(
125126
# Assumes warning level is the default if an invalid option is passed.
126127
logger.warning(display_message.format(**kwargs))
127128

128-
def warn(self, *, warning_code: str, message: str, **kwargs: Any) -> None:
129+
def warn(
130+
self, *, warning_code: WarningCodes, message: str, **kwargs: Any
131+
) -> None:
129132
"""Method to execute warning handling logic with warning log level.
130133
131134
Args:
@@ -135,7 +138,9 @@ def warn(self, *, warning_code: str, message: str, **kwargs: Any) -> None:
135138
"""
136139
self._log(warning_code, message, LoggingLevels.WARNING, **kwargs)
137140

138-
def info(self, *, warning_code: str, message: str, **kwargs: Any) -> None:
141+
def info(
142+
self, *, warning_code: WarningCodes, message: str, **kwargs: Any
143+
) -> None:
139144
"""Method to execute warning handling logic with info log level.
140145
141146
Args:

src/zenml/utils/warnings/registry.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Module for warning configurations organization and resolution."""
1515

16-
from enum import Enum
17-
16+
from zenml.utils.enum_utils import StrEnum
1817
from zenml.utils.warnings.base import (
1918
WarningCategory,
2019
WarningConfig,
@@ -23,7 +22,7 @@
2322
)
2423

2524

26-
class WarningCodes(str, Enum):
25+
class WarningCodes(StrEnum):
2726
"""Enum class organizing the warning codes."""
2827

2928
ZML001 = "ZML001"

tests/unit/utils/test_warnings.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from pydantic import ValidationError
33

44
from zenml.utils.warnings.base import (
5-
WarningConfig,
65
WarningCategory,
6+
WarningConfig,
77
WarningSeverity,
88
WarningVerbosity,
99
)
@@ -16,7 +16,7 @@ def controller() -> WarningController:
1616
1717
We reuse the singleton instance but clear its internal state to ensure test isolation.
1818
"""
19-
c = WarningController()
19+
c = WarningController.create()
2020
# hard reset of internal state
2121
c._warning_configs.clear()
2222
c._warning_statistics.clear()
@@ -109,7 +109,9 @@ def test_config_validations_reject_invalid_values():
109109
)
110110

111111

112-
def test_registration_merges_configs(controller: WarningController, sample_configs: dict[str, WarningConfig]):
112+
def test_registration_merges_configs(
113+
controller: WarningController, sample_configs: dict[str, WarningConfig]
114+
):
113115
controller.register(sample_configs)
114116
assert controller._warning_configs["TEST001"].code == "TEST001"
115117
assert controller._warning_configs["TEST002"].is_throttled is True
@@ -120,7 +122,9 @@ def test_singleton_instance_is_shared(controller: WarningController):
120122
assert controller is another
121123

122124

123-
def test_statistics_increment_for_non_throttled(controller: WarningController, sample_configs: dict[str, WarningConfig]):
125+
def test_statistics_increment_for_non_throttled(
126+
controller: WarningController, sample_configs: dict[str, WarningConfig]
127+
):
124128
controller.register(sample_configs)
125129

126130
controller.warn(warning_code="TEST001", message="once")
@@ -130,7 +134,9 @@ def test_statistics_increment_for_non_throttled(controller: WarningController, s
130134
assert controller._warning_statistics["TEST001"] == 3
131135

132136

133-
def test_statistics_throttled_only_once(controller: WarningController, sample_configs: dict[str, WarningConfig]):
137+
def test_statistics_throttled_only_once(
138+
controller: WarningController, sample_configs: dict[str, WarningConfig]
139+
):
134140
controller.register(sample_configs)
135141

136142
controller.warn(warning_code="TEST002", message="first")
@@ -140,7 +146,9 @@ def test_statistics_throttled_only_once(controller: WarningController, sample_co
140146
assert controller._warning_statistics["TEST002"] == 1
141147

142148

143-
def test_warn_and_info_do_not_break_with_format_kwargs(controller, sample_configs):
149+
def test_warn_and_info_do_not_break_with_format_kwargs(
150+
controller: WarningController, sample_configs: dict[str, WarningConfig]
151+
):
144152
controller.register(sample_configs)
145153

146154
controller.warn(
@@ -159,24 +167,39 @@ def test_warn_and_info_do_not_break_with_format_kwargs(controller, sample_config
159167
assert controller._warning_statistics["TEST001"] == 2
160168

161169

162-
def test_unregistered_warning_no_crash_and_no_stats_increment(controller):
163-
controller.warn(warning_code="UNKNOWN", message="should fallback to default")
170+
def test_unregistered_warning_no_crash_and_no_stats_increment(
171+
controller: WarningController,
172+
):
173+
controller.warn(
174+
warning_code="UNKNOWN", message="should fallback to default"
175+
)
164176
controller.info(warning_code="UNKNOWN", message="also fallback")
165177

166178
assert "UNKNOWN" not in controller._warning_statistics
167179

168180

169-
def test_get_display_message_varies_by_verbosity(sample_configs):
181+
def test_get_display_message_varies_by_verbosity(
182+
sample_configs: dict[str, WarningConfig],
183+
):
170184
from zenml.utils.warnings.controller import WarningController as WC
185+
171186
module_name = "mod.path"
172187
line_number = 42
173188

174-
low = sample_configs["TEST001"].model_copy(update={"verbosity": WarningVerbosity.LOW})
175-
med = sample_configs["TEST001"].model_copy(update={"verbosity": WarningVerbosity.MEDIUM})
176-
high = sample_configs["TEST001"].model_copy(update={"verbosity": WarningVerbosity.HIGH})
189+
low = sample_configs["TEST001"].model_copy(
190+
update={"verbosity": WarningVerbosity.LOW}
191+
)
192+
med = sample_configs["TEST001"].model_copy(
193+
update={"verbosity": WarningVerbosity.MEDIUM}
194+
)
195+
high = sample_configs["TEST001"].model_copy(
196+
update={"verbosity": WarningVerbosity.HIGH}
197+
)
177198

178199
msg_low = WC._get_display_message("hello", module_name, line_number, low)
179-
assert msg_low.startswith("[TEST001](WarningCategory.USAGE)") or msg_low.startswith("[TEST001](USAGE)")
200+
assert msg_low.startswith(
201+
"[TEST001](WarningCategory.USAGE)"
202+
) or msg_low.startswith("[TEST001](USAGE)")
180203
assert "hello" in msg_low
181204
assert "mod.path" not in msg_low and "42" not in msg_low
182205
assert low.description not in msg_low
@@ -191,3 +214,26 @@ def test_get_display_message_varies_by_verbosity(sample_configs):
191214
assert "[TEST001]" in msg_high and "hello" in msg_high
192215
assert "\n" in msg_high
193216
assert high.description.strip() in msg_high
217+
218+
219+
def test_combined_registries(
220+
controller: WarningController, sample_configs: dict[str, WarningConfig]
221+
):
222+
from zenml.utils.warnings.registry import (
223+
WARNING_CONFIG_REGISTRY,
224+
WarningCodes,
225+
)
226+
227+
controller.register(sample_configs)
228+
controller.register(WARNING_CONFIG_REGISTRY)
229+
230+
assert WarningCodes.ZML001.value in controller._warning_configs
231+
assert "TEST001" in controller._warning_configs
232+
233+
controller.warn(
234+
warning_code=WarningCodes.ZML001,
235+
message="You are warned {name}",
236+
name="tester",
237+
)
238+
239+
assert controller._warning_statistics["ZML001"] >= 1

0 commit comments

Comments
 (0)