Skip to content

Commit b9224d6

Browse files
committed
Step context changes
1 parent 1bb1cd3 commit b9224d6

File tree

3 files changed

+102
-80
lines changed

3 files changed

+102
-80
lines changed

src/zenml/orchestrators/step_runner.py

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,6 @@ def run(
184184

185185
self._stack.prepare_step_run(info=step_run_info)
186186

187-
# Initialize the step context singleton
188-
StepContext._clear()
189187
step_context = StepContext(
190188
pipeline_run=pipeline_run,
191189
step_run=step_run,
@@ -196,75 +194,78 @@ def run(
196194
},
197195
)
198196

199-
# Parse the inputs for the entrypoint function.
200-
function_params = self._parse_inputs(
201-
args=spec.args,
202-
annotations=spec.annotations,
203-
input_artifacts=input_artifacts,
204-
)
197+
with step_context:
198+
function_params = self._parse_inputs(
199+
args=spec.args,
200+
annotations=spec.annotations,
201+
input_artifacts=input_artifacts,
202+
)
205203

206-
# Get all step environment variables. For most orchestrators, the
207-
# non-secret environment variables have been set before by the
208-
# orchestrator. But for some orchestrators, this is not possible and
209-
# we therefore make sure to set them here so they're at least
210-
# available for the user code.
211-
step_environment = env_utils.get_step_environment(
212-
step_config=step_run.config, stack=self._stack
213-
)
214-
secret_environment = env_utils.get_step_secret_environment(
215-
step_config=step_run.config, stack=self._stack
216-
)
217-
step_environment.update(secret_environment)
218-
219-
step_failed = False
220-
try:
221-
if (
222-
# TODO: do we need to disable this for dynamic pipelines?
223-
pipeline_run.snapshot
224-
and self._stack.orchestrator.run_init_cleanup_at_step_level
225-
):
226-
self._stack.orchestrator.run_init_hook(
227-
snapshot=pipeline_run.snapshot
228-
)
204+
# Get all step environment variables. For most orchestrators, the
205+
# non-secret environment variables have been set before by the
206+
# orchestrator. But for some orchestrators, this is not possible and
207+
# we therefore make sure to set them here so they're at least
208+
# available for the user code.
209+
step_environment = env_utils.get_step_environment(
210+
step_config=step_run.config, stack=self._stack
211+
)
212+
secret_environment = env_utils.get_step_secret_environment(
213+
step_config=step_run.config, stack=self._stack
214+
)
215+
step_environment.update(secret_environment)
229216

230-
with env_utils.temporary_environment(step_environment):
231-
return_values = step_instance.call_entrypoint(
232-
**function_params
233-
)
234-
except BaseException as step_exception: # noqa: E722
235-
step_failed = True
217+
step_failed = False
218+
try:
219+
if (
220+
# TODO: do we need to disable this for dynamic pipelines?
221+
pipeline_run.snapshot
222+
and self._stack.orchestrator.run_init_cleanup_at_step_level
223+
):
224+
self._stack.orchestrator.run_init_hook(
225+
snapshot=pipeline_run.snapshot
226+
)
236227

237-
exception_info = exception_utils.collect_exception_information(
238-
step_exception, step_instance
239-
)
228+
with env_utils.temporary_environment(step_environment):
229+
return_values = step_instance.call_entrypoint(
230+
**function_params
231+
)
232+
except BaseException as step_exception: # noqa: E722
233+
step_failed = True
240234

241-
if ENV_ZENML_STEP_OPERATOR in os.environ:
242-
# We're running in a step operator environment, so we can't
243-
# depend on the step launcher to publish the exception info
244-
Client().zen_store.update_run_step(
245-
step_run_id=step_run_info.step_run_id,
246-
step_run_update=StepRunUpdate(
247-
exception_info=exception_info,
248-
),
235+
exception_info = (
236+
exception_utils.collect_exception_information(
237+
step_exception, step_instance
238+
)
249239
)
250-
else:
251-
# This will be published by the step launcher
252-
step_exception_info.set(exception_info)
253240

254-
if not step_run.is_retriable:
255-
if (
256-
failure_hook_source
257-
:= self.configuration.failure_hook_source
258-
):
259-
logger.info("Detected failure hook. Running...")
260-
with env_utils.temporary_environment(step_environment):
261-
load_and_run_hook(
262-
failure_hook_source,
263-
step_exception=step_exception,
264-
)
265-
raise
266-
finally:
267-
try:
241+
if ENV_ZENML_STEP_OPERATOR in os.environ:
242+
# We're running in a step operator environment, so we can't
243+
# depend on the step launcher to publish the exception info
244+
Client().zen_store.update_run_step(
245+
step_run_id=step_run_info.step_run_id,
246+
step_run_update=StepRunUpdate(
247+
exception_info=exception_info,
248+
),
249+
)
250+
else:
251+
# This will be published by the step launcher
252+
step_exception_info.set(exception_info)
253+
254+
if not step_run.is_retriable:
255+
if (
256+
failure_hook_source
257+
:= self.configuration.failure_hook_source
258+
):
259+
logger.info("Detected failure hook. Running...")
260+
with env_utils.temporary_environment(
261+
step_environment
262+
):
263+
load_and_run_hook(
264+
failure_hook_source,
265+
step_exception=step_exception,
266+
)
267+
raise
268+
finally:
268269
step_run_metadata = self._stack.get_step_run_metadata(
269270
info=step_run_info,
270271
)
@@ -338,12 +339,6 @@ def run(
338339
snapshot=pipeline_run.snapshot
339340
)
340341

341-
finally:
342-
step_context._cleanup_registry.execute_callbacks(
343-
raise_on_exception=False
344-
)
345-
StepContext._clear() # Remove the step context singleton
346-
347342
# Update the status and output artifacts of the step run.
348343
output_artifact_ids = {
349344
output_name: [

src/zenml/steps/step_context.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,24 @@
1313
# permissions and limitations under the License.
1414
"""Step context class."""
1515

16+
import contextvars
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
1920
Dict,
2021
List,
2122
Mapping,
2223
Optional,
24+
Self,
2325
Sequence,
2426
Type,
2527
)
2628

2729
from zenml.exceptions import StepContextError
2830
from zenml.logger import get_logger
31+
from zenml.utils import context_utils
2932
from zenml.utils.callback_registry import CallbackRegistry
30-
from zenml.utils.singleton import SingletonMetaClass, ThreadLocalSingleton
33+
from zenml.utils.singleton import SingletonMetaClass
3134

3235
if TYPE_CHECKING:
3336
from zenml.artifacts.artifact_config import ArtifactConfig
@@ -54,8 +57,9 @@ def get_step_context() -> "StepContext":
5457
Raises:
5558
RuntimeError: If no step is currently running.
5659
"""
57-
if StepContext._exists():
58-
return StepContext() # type: ignore
60+
if ctx := StepContext.get():
61+
return ctx
62+
5963
raise RuntimeError(
6064
"The step context is only available inside a step function."
6165
)
@@ -110,13 +114,9 @@ def initialize(self, state: Optional[Any]) -> None:
110114
self.initialized = True
111115

112116

113-
# TODO: use base context class
114-
class StepContext(metaclass=ThreadLocalSingleton):
117+
class StepContext(context_utils.BaseContext):
115118
"""Provides additional context inside a step function.
116119
117-
This singleton class is used to access information about the current run,
118-
step run, or its outputs inside a step function.
119-
120120
Usage example:
121121
122122
```python
@@ -139,6 +139,8 @@ def my_trainer_step() -> Any:
139139
```
140140
"""
141141

142+
__context_var__ = contextvars.ContextVar("step_context")
143+
142144
def __init__(
143145
self,
144146
pipeline_run: "PipelineRunResponse",
@@ -165,6 +167,8 @@ def __init__(
165167
"""
166168
from zenml.client import Client
167169

170+
super().__init__()
171+
168172
try:
169173
pipeline_run = Client().get_pipeline_run(pipeline_run.id)
170174
except KeyError:
@@ -462,6 +466,30 @@ def remove_output_tags(
462466
return
463467
output.tags = [tag for tag in output.tags if tag not in tags]
464468

469+
def __enter__(self) -> Self:
470+
"""Enter the step context.
471+
472+
Raises:
473+
RuntimeError: If the step context has already been entered.
474+
475+
Returns:
476+
The step context object.
477+
"""
478+
if self._token is not None:
479+
raise RuntimeError(
480+
"Running a step from within another step is not allowed."
481+
)
482+
return super().__enter__()
483+
484+
def __exit__(self, *_: Any) -> None:
485+
"""Exit the step context.
486+
487+
Raises:
488+
RuntimeError: If the step context has not been entered.
489+
"""
490+
self._cleanup_registry.execute_callbacks(raise_on_exception=False)
491+
super().__exit__(*_)
492+
465493

466494
class StepContextOutput:
467495
"""Represents a step output in the step context."""

src/zenml/zen_stores/migrations/versions/af27025fe19c_dynamic_pipelines.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def upgrade() -> None:
2626
)
2727

2828
with op.batch_alter_table("step_configuration", schema=None) as batch_op:
29-
# TODO: missing check constraint
3029
batch_op.add_column(
3130
sa.Column(
3231
"step_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True

0 commit comments

Comments
 (0)