|
16 | 16 | from typing import ( |
17 | 17 | TYPE_CHECKING, |
18 | 18 | Any, |
| 19 | + Callable, |
19 | 20 | Dict, |
20 | 21 | Optional, |
21 | 22 | Tuple, |
|
37 | 38 | from zenml.types import HookSpecification, InitHookSpecification |
38 | 39 |
|
39 | 40 |
|
| 41 | +def _validate_hook_arguments( |
| 42 | + _func: Callable[..., Any], |
| 43 | + hook_kwargs: Dict[str, Any], |
| 44 | + exception_arg: Union[BaseException, bool] = False, |
| 45 | +) -> Dict[str, Any]: |
| 46 | + """Validates hook arguments. |
| 47 | +
|
| 48 | + Args: |
| 49 | + func: The hook function to validate. |
| 50 | + hook_kwargs: The hook keyword arguments to validate. |
| 51 | + exception_arg: The exception argument to validate. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + The validated hook arguments. |
| 55 | +
|
| 56 | + Raises: |
| 57 | + HookValidationException: If the hook arguments are not valid. |
| 58 | + """ |
| 59 | + # Validate hook arguments |
| 60 | + try: |
| 61 | + hook_args: Tuple[Any, ...] = () |
| 62 | + if isinstance(exception_arg, BaseException): |
| 63 | + hook_args = (exception_arg,) |
| 64 | + elif exception_arg is True: |
| 65 | + hook_args = (Exception(),) |
| 66 | + config = ConfigDict(arbitrary_types_allowed=len(hook_args) > 0) |
| 67 | + validated_kwargs = validate_function_args( |
| 68 | + _func, config, *hook_args, **hook_kwargs |
| 69 | + ) |
| 70 | + except (ValidationError, TypeError) as e: |
| 71 | + exc_msg = ( |
| 72 | + "Failed to validate hook arguments for {func}: {e}\n" |
| 73 | + "Please observe the following guidelines:\n" |
| 74 | + "- the success hook takes no arguments\n" |
| 75 | + "- the failure hook optionally takes a single `BaseException` " |
| 76 | + "typed argument\n" |
| 77 | + "- the init hook takes any number of JSON-safe arguments\n" |
| 78 | + "- the cleanup hook takes no arguments\n" |
| 79 | + ) |
| 80 | + |
| 81 | + if not hook_args: |
| 82 | + raise HookValidationException(exc_msg.format(func=_func, e=e)) |
| 83 | + |
| 84 | + # If we have an exception argument, we try again without it. This is |
| 85 | + # to account for the case where the hook function does not expect an |
| 86 | + # exception argument. |
| 87 | + hook_args = () |
| 88 | + config = ConfigDict(arbitrary_types_allowed=False) |
| 89 | + try: |
| 90 | + validated_kwargs = validate_function_args( |
| 91 | + _func, config, *hook_args, **hook_kwargs |
| 92 | + ) |
| 93 | + except (ValidationError, TypeError) as e: |
| 94 | + raise HookValidationException(exc_msg.format(func=_func, e=e)) |
| 95 | + |
| 96 | + return validated_kwargs |
| 97 | + |
| 98 | + |
40 | 99 | def resolve_and_validate_hook( |
41 | 100 | hook: Union["HookSpecification", "InitHookSpecification"], |
42 | 101 | hook_kwargs: Optional[Dict[str, Any]] = None, |
@@ -68,24 +127,9 @@ def resolve_and_validate_hook( |
68 | 127 | raise ValueError(f"{func} is not a valid function.") |
69 | 128 |
|
70 | 129 | # Validate hook arguments |
71 | | - try: |
72 | | - hook_args: Tuple[Any, ...] = () |
73 | | - if allow_exception_arg: |
74 | | - hook_args = (Exception(),) |
75 | | - hook_kwargs = hook_kwargs or {} |
76 | | - config = ConfigDict(arbitrary_types_allowed=allow_exception_arg) |
77 | | - validated_kwargs = validate_function_args( |
78 | | - func, config, *hook_args, **hook_kwargs |
79 | | - ) |
80 | | - except (ValidationError, TypeError) as e: |
81 | | - raise HookValidationException( |
82 | | - f"Failed to validate hook arguments for {func}: {e}\n" |
83 | | - "Please observe the following guidelines:\n" |
84 | | - "- the success hook takes no arguments\n" |
85 | | - "- the failure hook takes a single `BaseException` typed argument\n" |
86 | | - "- the init hook takes any number of JSON-safe arguments\n" |
87 | | - "- the cleanup hook takes no arguments\n" |
88 | | - ) |
| 130 | + validated_kwargs = _validate_hook_arguments( |
| 131 | + func, hook_kwargs or {}, allow_exception_arg |
| 132 | + ) |
89 | 133 |
|
90 | 134 | return source_utils.resolve(func), validated_kwargs |
91 | 135 |
|
@@ -120,28 +164,14 @@ def load_and_run_hook( |
120 | 164 | logger.error(msg) |
121 | 165 | return None |
122 | 166 | try: |
123 | | - # Validate hook arguments |
124 | | - hook_args: Tuple[Any, ...] = () |
125 | | - if step_exception: |
126 | | - hook_args = (step_exception,) |
127 | | - hook_parameters = hook_parameters or {} |
128 | | - config = ConfigDict(arbitrary_types_allowed=step_exception is not None) |
129 | | - validated_kwargs = validate_function_args( |
130 | | - hook, config, *hook_args, **hook_parameters |
131 | | - ) |
132 | | - except (ValueError, TypeError) as e: |
133 | | - msg = ( |
134 | | - f"Failed to validate hook arguments for {hook}: {e}\n" |
135 | | - "Please observe the following guidelines:\n" |
136 | | - "- the success hook takes no arguments\n" |
137 | | - "- the failure hook takes a single `BaseException` typed argument\n" |
138 | | - "- the init hook takes any number of JSON-safe arguments\n" |
139 | | - "- the cleanup hook takes no arguments\n" |
| 167 | + validated_kwargs = _validate_hook_arguments( |
| 168 | + hook, hook_parameters or {}, step_exception or False |
140 | 169 | ) |
| 170 | + except HookValidationException as e: |
141 | 171 | if raise_on_error: |
142 | | - raise RuntimeError(msg) from e |
| 172 | + raise |
143 | 173 | else: |
144 | | - logger.error(msg) |
| 174 | + logger.error(e) |
145 | 175 | return None |
146 | 176 |
|
147 | 177 | try: |
|
0 commit comments