Skip to content

Commit 0778e8f

Browse files
committed
Fix hook validators to allow for failure hook without any args
1 parent dd3bd5c commit 0778e8f

File tree

2 files changed

+68
-38
lines changed

2 files changed

+68
-38
lines changed

docs/book/how-to/steps-pipelines/advanced_features.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ def my_step():
642642
The following conventions apply to hooks:
643643

644644
* the success hook takes no arguments
645-
* the failure hook takes a single `BaseException` typed argument
645+
* the failure hook optionally takes a single `BaseException` typed argument
646646

647647
You can also define hooks at the pipeline level to apply to all steps:
648648

src/zenml/hooks/hook_validators.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import (
1717
TYPE_CHECKING,
1818
Any,
19+
Callable,
1920
Dict,
2021
Optional,
2122
Tuple,
@@ -37,6 +38,64 @@
3738
from zenml.types import HookSpecification, InitHookSpecification
3839

3940

41+
def _validate_hook_arguments(
42+
_func: Callable[..., Any],
43+
hook_kwargs: Dict[str, Any],
44+
exception_arg: Union[BaseException, bool] = False,
45+
) -> Dict[str, Any]:
46+
"""Validates hook arguments.
47+
48+
Args:
49+
func: The hook function to validate.
50+
hook_kwargs: The hook keyword arguments to validate.
51+
exception_arg: The exception argument to validate.
52+
53+
Returns:
54+
The validated hook arguments.
55+
56+
Raises:
57+
HookValidationException: If the hook arguments are not valid.
58+
"""
59+
# Validate hook arguments
60+
try:
61+
hook_args: Tuple[Any, ...] = ()
62+
if isinstance(exception_arg, BaseException):
63+
hook_args = (exception_arg,)
64+
elif exception_arg is True:
65+
hook_args = (Exception(),)
66+
config = ConfigDict(arbitrary_types_allowed=len(hook_args) > 0)
67+
validated_kwargs = validate_function_args(
68+
_func, config, *hook_args, **hook_kwargs
69+
)
70+
except (ValidationError, TypeError) as e:
71+
exc_msg = (
72+
"Failed to validate hook arguments for {func}: {e}\n"
73+
"Please observe the following guidelines:\n"
74+
"- the success hook takes no arguments\n"
75+
"- the failure hook optionally takes a single `BaseException` "
76+
"typed argument\n"
77+
"- the init hook takes any number of JSON-safe arguments\n"
78+
"- the cleanup hook takes no arguments\n"
79+
)
80+
81+
if not hook_args:
82+
raise HookValidationException(exc_msg.format(func=_func, e=e))
83+
84+
# If we have an exception argument, we try again without it. This is
85+
# to account for the case where the hook function does not expect an
86+
# exception argument.
87+
hook_args = ()
88+
config = ConfigDict(arbitrary_types_allowed=False)
89+
try:
90+
validated_kwargs = validate_function_args(
91+
_func, config, *hook_args, **hook_kwargs
92+
)
93+
except (ValidationError, TypeError) as e:
94+
raise HookValidationException(exc_msg.format(func=_func, e=e))
95+
96+
return validated_kwargs
97+
98+
4099
def resolve_and_validate_hook(
41100
hook: Union["HookSpecification", "InitHookSpecification"],
42101
hook_kwargs: Optional[Dict[str, Any]] = None,
@@ -68,24 +127,9 @@ def resolve_and_validate_hook(
68127
raise ValueError(f"{func} is not a valid function.")
69128

70129
# Validate hook arguments
71-
try:
72-
hook_args: Tuple[Any, ...] = ()
73-
if allow_exception_arg:
74-
hook_args = (Exception(),)
75-
hook_kwargs = hook_kwargs or {}
76-
config = ConfigDict(arbitrary_types_allowed=allow_exception_arg)
77-
validated_kwargs = validate_function_args(
78-
func, config, *hook_args, **hook_kwargs
79-
)
80-
except (ValidationError, TypeError) as e:
81-
raise HookValidationException(
82-
f"Failed to validate hook arguments for {func}: {e}\n"
83-
"Please observe the following guidelines:\n"
84-
"- the success hook takes no arguments\n"
85-
"- the failure hook takes a single `BaseException` typed argument\n"
86-
"- the init hook takes any number of JSON-safe arguments\n"
87-
"- the cleanup hook takes no arguments\n"
88-
)
130+
validated_kwargs = _validate_hook_arguments(
131+
func, hook_kwargs or {}, allow_exception_arg
132+
)
89133

90134
return source_utils.resolve(func), validated_kwargs
91135

@@ -120,28 +164,14 @@ def load_and_run_hook(
120164
logger.error(msg)
121165
return None
122166
try:
123-
# Validate hook arguments
124-
hook_args: Tuple[Any, ...] = ()
125-
if step_exception:
126-
hook_args = (step_exception,)
127-
hook_parameters = hook_parameters or {}
128-
config = ConfigDict(arbitrary_types_allowed=step_exception is not None)
129-
validated_kwargs = validate_function_args(
130-
hook, config, *hook_args, **hook_parameters
131-
)
132-
except (ValueError, TypeError) as e:
133-
msg = (
134-
f"Failed to validate hook arguments for {hook}: {e}\n"
135-
"Please observe the following guidelines:\n"
136-
"- the success hook takes no arguments\n"
137-
"- the failure hook takes a single `BaseException` typed argument\n"
138-
"- the init hook takes any number of JSON-safe arguments\n"
139-
"- the cleanup hook takes no arguments\n"
167+
validated_kwargs = _validate_hook_arguments(
168+
hook, hook_parameters or {}, step_exception or False
140169
)
170+
except HookValidationException as e:
141171
if raise_on_error:
142-
raise RuntimeError(msg) from e
172+
raise
143173
else:
144-
logger.error(msg)
174+
logger.error(e)
145175
return None
146176

147177
try:

0 commit comments

Comments
 (0)