Skip to content

Commit 575843d

Browse files
committed
Maybe solution for step configs
1 parent 7956099 commit 575843d

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

src/zenml/config/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from zenml.exceptions import StackValidationError
4343
from zenml.models import PipelineSnapshotBase
4444
from zenml.pipelines.run_utils import get_default_run_name
45+
from zenml.steps.step_invocation import StepInvocation
4546
from zenml.utils import pydantic_utils, secret_utils, settings_utils
4647

4748
if TYPE_CHECKING:
@@ -536,6 +537,9 @@ def _get_sorted_invocations(
536537
Returns:
537538
The sorted steps.
538539
"""
540+
if pipeline.is_dynamic:
541+
return list(pipeline.invocations.items())
542+
539543
from zenml.orchestrators.dag_runner import reverse_dag
540544
from zenml.orchestrators.topsort import topsorted_layers
541545

src/zenml/pipelines/dynamic/runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def _compile_step(
286286
else:
287287
external_artifacts[name] = ExternalArtifact(value=value)
288288

289-
if template := get_static_step_template(snapshot, step):
289+
if template := get_static_step_template(snapshot, step, pipeline):
290290
step._configuration = template.config
291291

292292
step._apply_dynamic_configuration()
@@ -417,11 +417,12 @@ def _runs_in_process(step: "Step") -> bool:
417417
def get_static_step_template(
418418
snapshot: "PipelineSnapshotResponse",
419419
step: "BaseStep",
420+
pipeline: "DynamicPipeline",
420421
) -> Optional["Step"]:
421-
step_source = step.resolve().import_path
422-
423-
for compiled_step in snapshot.step_configurations.values():
424-
if compiled_step.spec.source.import_path == step_source:
425-
return compiled_step
422+
for index, step_ in enumerate(pipeline.depends_on):
423+
if step_._static_id == step._static_id:
424+
break
425+
else:
426+
return None
426427

427-
return None
428+
return list(snapshot.step_configurations.values())[index]

src/zenml/steps/base_step.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def __init__(
166166
reserved_arguments=["after", "id"],
167167
)
168168

169+
self._static_id = id(self)
169170
name = name or self.__class__.__name__
170171

171172
logger.debug(
@@ -881,7 +882,14 @@ def copy(self) -> "BaseStep":
881882
Returns:
882883
The step copy.
883884
"""
884-
return copy.deepcopy(self)
885+
copy_ = copy.deepcopy(self)
886+
887+
from zenml.pipelines.dynamic.context import DynamicPipelineRunContext
888+
889+
if not DynamicPipelineRunContext.get():
890+
copy_._static_id = id(copy_)
891+
892+
return copy_
885893

886894
def _apply_configuration(
887895
self,

0 commit comments

Comments
 (0)