Skip to content

Commit 438b003

Browse files
committed
Threadsafe static pipeline compilation
1 parent 137413b commit 438b003

File tree

7 files changed

+214
-112
lines changed

7 files changed

+214
-112
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import contextvars
2+
from typing import TYPE_CHECKING, Self
3+
4+
from zenml.utils import context_utils
5+
6+
if TYPE_CHECKING:
7+
from zenml.pipelines.pipeline_definition import Pipeline
8+
9+
10+
class PipelineCompilationContext(context_utils.BaseContext):
11+
"""Pipeline compilation context."""
12+
13+
__context_var__ = contextvars.ContextVar("pipeline_compilation_context")
14+
15+
def __init__(
16+
self,
17+
pipeline: "Pipeline",
18+
) -> None:
19+
"""Initialize the pipeline compilation context.
20+
21+
Args:
22+
pipeline: The pipeline that is being compiled.
23+
"""
24+
super().__init__()
25+
self._pipeline = pipeline
26+
27+
@property
28+
def pipeline(self) -> "Pipeline":
29+
"""The pipeline that is being compiled.
30+
31+
Returns:
32+
The pipeline that is being compiled.
33+
"""
34+
return self._pipeline
35+
36+
def __enter__(self) -> Self:
37+
"""Enter the pipeline compilation context.
38+
39+
Raises:
40+
RuntimeError: If the pipeline compilation context has already been
41+
entered.
42+
43+
Returns:
44+
The pipeline compilation context object.
45+
"""
46+
if self._token is not None:
47+
raise RuntimeError(
48+
"Compiling a pipeline while another pipeline is being compiled "
49+
"is not allowed."
50+
)
51+
return super().__enter__()
Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,30 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Dynamic pipeline run context."""
15+
116
import contextvars
2-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self, cast
17+
from typing import TYPE_CHECKING, Self
18+
19+
from zenml.utils import context_utils
320

421
if TYPE_CHECKING:
522
from zenml.models import PipelineRunResponse, PipelineSnapshotResponse
623
from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline
724
from zenml.pipelines.dynamic.runner import DynamicPipelineRunner
825

926

10-
class BaseContext:
11-
"""Base context class."""
12-
13-
__context_var__: ClassVar[contextvars.ContextVar[Self]]
14-
15-
def __init__(self) -> None:
16-
"""Initialize the context."""
17-
self._token: Optional[contextvars.Token[Any]] = None
18-
19-
@classmethod
20-
def get(cls: type[Self]) -> Optional[Self]:
21-
"""Get the active context for the current thread.
22-
23-
Returns:
24-
The active context for the current thread.
25-
"""
26-
return cast(Optional[Self], cls.__context_var__.get(None))
27-
28-
def __enter__(self) -> Self:
29-
"""Enter the context.
30-
31-
Returns:
32-
The context object.
33-
"""
34-
self._token = self.__context_var__.set(self)
35-
return self
36-
37-
def __exit__(self, *_: Any) -> None:
38-
"""Exit the context.
39-
40-
Raises:
41-
RuntimeError: If the context has not been entered.
42-
"""
43-
if not self._token:
44-
raise RuntimeError(
45-
f"Can't exit {self.__class__.__name__} because it has not been "
46-
"entered."
47-
)
48-
self.__context_var__.reset(self._token)
49-
50-
51-
class DynamicPipelineRunContext(BaseContext):
27+
class DynamicPipelineRunContext(context_utils.BaseContext):
5228
"""Dynamic pipeline run context."""
5329

5430
__context_var__ = contextvars.ContextVar("dynamic_pipeline_run_context")
@@ -76,18 +52,38 @@ def __init__(
7652

7753
@property
7854
def pipeline(self) -> "DynamicPipeline":
55+
"""The pipeline that is being executed.
56+
57+
Returns:
58+
The pipeline that is being executed.
59+
"""
7960
return self._pipeline
8061

8162
@property
8263
def run(self) -> "PipelineRunResponse":
64+
"""The pipeline run.
65+
66+
Returns:
67+
The pipeline run.
68+
"""
8369
return self._run
8470

8571
@property
8672
def snapshot(self) -> "PipelineSnapshotResponse":
73+
"""The snapshot of the pipeline.
74+
75+
Returns:
76+
The snapshot of the pipeline.
77+
"""
8778
return self._snapshot
8879

8980
@property
9081
def runner(self) -> "DynamicPipelineRunner":
82+
"""The runner executing the pipeline.
83+
84+
Returns:
85+
The runner executing the pipeline.
86+
"""
9187
return self._runner
9288

9389
def __enter__(self) -> Self:
@@ -105,7 +101,3 @@ def __enter__(self) -> Self:
105101
"Calling a pipeline within a dynamic pipeline is not allowed."
106102
)
107103
return super().__enter__()
108-
109-
110-
def executing_dynamic_pipeline() -> bool:
111-
return DynamicPipelineRunContext.get() is not None

src/zenml/pipelines/dynamic/runner.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Dynamic pipeline runner."""
15+
116
import contextvars
217
import inspect
318
import time
@@ -34,8 +49,8 @@
3449
publish_successful_pipeline_run,
3550
)
3651
from zenml.orchestrators.step_launcher import StepLauncher
37-
from zenml.pipelines.dynamic.context import DynamicPipelineRunContext
3852
from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline
53+
from zenml.pipelines.dynamic.run_context import DynamicPipelineRunContext
3954
from zenml.pipelines.run_utils import create_placeholder_run
4055
from zenml.stack import Stack
4156
from zenml.steps.entrypoint_function_utils import StepArtifact

src/zenml/pipelines/pipeline_context.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ def get_pipeline_context() -> "PipelineContext":
3333
RuntimeError: If no active pipeline is found.
3434
RuntimeError: If inside a running step.
3535
"""
36-
from zenml.pipelines.pipeline_definition import Pipeline
36+
from zenml.pipelines.compilation_context import PipelineCompilationContext
3737

38-
if Pipeline.ACTIVE_PIPELINE is None:
38+
context = PipelineCompilationContext.get()
39+
40+
if context is None:
3941
try:
4042
from zenml.steps.step_context import get_step_context
4143

@@ -49,7 +51,7 @@ def get_pipeline_context() -> "PipelineContext":
4951
)
5052

5153
return PipelineContext(
52-
pipeline_configuration=Pipeline.ACTIVE_PIPELINE.configuration
54+
pipeline_configuration=context.pipeline.configuration
5355
)
5456

5557

src/zenml/pipelines/pipeline_definition.py

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
TYPE_CHECKING,
2323
Any,
2424
Callable,
25-
ClassVar,
2625
Dict,
2726
Iterator,
2827
List,
@@ -81,6 +80,7 @@
8180
ScheduleRequest,
8281
)
8382
from zenml.pipelines import build_utils
83+
from zenml.pipelines.compilation_context import PipelineCompilationContext
8484
from zenml.pipelines.run_utils import (
8585
create_placeholder_run,
8686
should_prevent_pipeline_execution,
@@ -131,11 +131,6 @@
131131
class Pipeline:
132132
"""ZenML pipeline class."""
133133

134-
# The active pipeline is the pipeline to which step invocations will be
135-
# added when a step is called. It is set using a context manager when a
136-
# pipeline is called (see Pipeline.__call__ for more context)
137-
ACTIVE_PIPELINE: ClassVar[Optional["Pipeline"]] = None
138-
139134
def __init__(
140135
self,
141136
name: str,
@@ -627,10 +622,7 @@ def pipeline_(param_name: str):
627622
if k not in kwargs:
628623
kwargs[k] = v_config
629624

630-
with self:
631-
# Enter the context manager, so we become the active pipeline. This
632-
# means that all steps that get called while the entrypoint function
633-
# is executed will be added as invocation to this pipeline instance.
625+
with PipelineCompilationContext(pipeline=self):
634626
self._call_entrypoint(*args, **kwargs)
635627

636628
def register(self) -> "PipelineResponse":
@@ -1361,7 +1353,15 @@ def add_step_invocation(
13611353
Returns:
13621354
The step invocation ID.
13631355
"""
1364-
if not self.is_dynamic and Pipeline.ACTIVE_PIPELINE != self:
1356+
from zenml.pipelines.dynamic.run_context import (
1357+
DynamicPipelineRunContext,
1358+
)
1359+
1360+
context = (
1361+
PipelineCompilationContext.get() or DynamicPipelineRunContext.get()
1362+
)
1363+
1364+
if not context or context.pipeline != self:
13651365
raise RuntimeError(
13661366
"A step invocation can only be added to an active pipeline."
13671367
)
@@ -1429,32 +1429,6 @@ def _compute_invocation_id(
14291429

14301430
raise RuntimeError("Unable to find step ID")
14311431

1432-
def __enter__(self) -> Self:
1433-
"""Activate the pipeline context.
1434-
1435-
Raises:
1436-
RuntimeError: If a different pipeline is already active.
1437-
1438-
Returns:
1439-
The pipeline instance.
1440-
"""
1441-
if Pipeline.ACTIVE_PIPELINE:
1442-
raise RuntimeError(
1443-
"Unable to enter pipeline context. A different pipeline "
1444-
f"{Pipeline.ACTIVE_PIPELINE.name} is already active."
1445-
)
1446-
1447-
Pipeline.ACTIVE_PIPELINE = self
1448-
return self
1449-
1450-
def __exit__(self, *args: Any) -> None:
1451-
"""Deactivates the pipeline context.
1452-
1453-
Args:
1454-
*args: The arguments passed to the context exit handler.
1455-
"""
1456-
Pipeline.ACTIVE_PIPELINE = None
1457-
14581432
def _parse_config_file(
14591433
self, config_path: Optional[str], matcher: List[str]
14601434
) -> Dict[str, Any]:
@@ -1590,7 +1564,7 @@ def __call__(
15901564
`entrypoint` method. Otherwise, returns the pipeline run or `None`
15911565
if running with a schedule.
15921566
"""
1593-
if Pipeline.ACTIVE_PIPELINE:
1567+
if PipelineCompilationContext.is_active():
15941568
# Calling a pipeline inside a pipeline, we return the potential
15951569
# outputs of the entrypoint function
15961570

0 commit comments

Comments
 (0)