File tree Expand file tree Collapse file tree 5 files changed +25
-20
lines changed Expand file tree Collapse file tree 5 files changed +25
-20
lines changed Original file line number Diff line number Diff line change 4242from zenml .exceptions import StackValidationError
4343from zenml .models import PipelineSnapshotBase
4444from zenml .pipelines .run_utils import get_default_run_name
45+ from zenml .steps .step_invocation import StepInvocation
4546from zenml .utils import pydantic_utils , secret_utils , settings_utils
4647
4748if TYPE_CHECKING :
@@ -536,6 +537,9 @@ def _get_sorted_invocations(
536537 Returns:
537538 The sorted steps.
538539 """
540+ if pipeline .is_dynamic :
541+ return list (pipeline .invocations .items ())
542+
539543 from zenml .orchestrators .dag_runner import reverse_dag
540544 from zenml .orchestrators .topsort import topsorted_layers
541545
Original file line number Diff line number Diff line change @@ -261,6 +261,7 @@ class PartialStepConfiguration(StepConfigurationUpdate):
261261 """Class representing a partial step configuration."""
262262
263263 name : str
264+ template : Optional [str ] = None
264265 parameters : Dict [str , Any ] = {}
265266 settings : Dict [str , SerializeAsAny [BaseSettings ]] = {}
266267 environment : Dict [str , str ] = {}
Original file line number Diff line number Diff line change @@ -59,18 +59,8 @@ def get_image(self, key: str) -> str:
5959 )
6060
6161 if self .snapshot .is_dynamic :
62- # TODO: better way for this
63- for (
64- invocation_id ,
65- compiled_step ,
66- ) in self .snapshot .step_configurations .values ():
67- if (
68- compiled_step .spec .source .import_path
69- == self .spec .source .import_path
70- ):
71- step_key = invocation_id
72- break
73- else :
62+ step_key = self .config .template
63+ if not step_key :
7464 logger .warning (
7565 "Unable to find config template for step %s. Falling "
7666 "back to the pipeline image." ,
Original file line number Diff line number Diff line change @@ -286,8 +286,9 @@ def _compile_step(
286286 else :
287287 external_artifacts [name ] = ExternalArtifact (value = value )
288288
289- if template := get_static_step_template (snapshot , step ):
289+ if template := get_static_step_template (snapshot , step , pipeline ):
290290 step ._configuration = template .config
291+ step ._configuration .template = template .spec .invocation_id
291292
292293 step ._apply_dynamic_configuration ()
293294 invocation_id = pipeline .add_dynamic_invocation (
@@ -417,11 +418,12 @@ def _runs_in_process(step: "Step") -> bool:
417418def get_static_step_template (
418419 snapshot : "PipelineSnapshotResponse" ,
419420 step : "BaseStep" ,
421+ pipeline : "DynamicPipeline" ,
420422) -> Optional ["Step" ]:
421- step_source = step . resolve (). import_path
422-
423- for compiled_step in snapshot . step_configurations . values ():
424- if compiled_step . spec . source . import_path == step_source :
425- return compiled_step
423+ for index , step_ in enumerate ( pipeline . depends_on ):
424+ if step_ . _static_id == step . _static_id :
425+ break
426+ else :
427+ return None
426428
427- return None
429+ return list ( snapshot . step_configurations . values ())[ index ]
Original file line number Diff line number Diff line change @@ -166,6 +166,7 @@ def __init__(
166166 reserved_arguments = ["after" , "id" ],
167167 )
168168
169+ self ._static_id = id (self )
169170 name = name or self .__class__ .__name__
170171
171172 logger .debug (
@@ -881,7 +882,14 @@ def copy(self) -> "BaseStep":
881882 Returns:
882883 The step copy.
883884 """
884- return copy .deepcopy (self )
885+ copy_ = copy .deepcopy (self )
886+
887+ from zenml .pipelines .dynamic .context import DynamicPipelineRunContext
888+
889+ if not DynamicPipelineRunContext .get ():
890+ copy_ ._static_id = id (copy_ )
891+
892+ return copy_
885893
886894 def _apply_configuration (
887895 self ,
You can’t perform that action at this time.
0 commit comments