Skip to content

Commit cffc807

Browse files
committed
Fix DAG
1 parent b0b6162 commit cffc807

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6053,13 +6053,15 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
60536053
"""
60546054
helper = DAGGeneratorHelper()
60556055
with Session(self.engine) as session:
6056+
# TODO: better loads for dynamic/static pipelines
60566057
run = self._get_schema_by_id(
60576058
resource_id=pipeline_run_id,
60586059
schema_class=PipelineRunSchema,
60596060
session=session,
60606061
query_options=[
60616062
selectinload(jl_arg(PipelineRunSchema.snapshot)).load_only(
60626063
jl_arg(PipelineSnapshotSchema.pipeline_configuration),
6064+
jl_arg(PipelineSnapshotSchema.is_dynamic),
60636065
),
60646066
selectinload(
60656067
jl_arg(PipelineRunSchema.snapshot)
@@ -6072,6 +6074,9 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
60726074
selectinload(
60736075
jl_arg(PipelineRunSchema.step_runs)
60746076
).selectinload(jl_arg(StepRunSchema.output_artifacts)),
6077+
selectinload(
6078+
jl_arg(PipelineRunSchema.step_runs)
6079+
).selectinload(jl_arg(StepRunSchema.dynamic_config)),
60756080
selectinload(jl_arg(PipelineRunSchema.step_runs))
60766081
.selectinload(jl_arg(StepRunSchema.triggered_runs))
60776082
.load_only(
@@ -6098,13 +6103,23 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
60986103
start_time=run.start_time, inplace=True
60996104
)
61006105

6101-
steps = {
6102-
config_table.name: Step.from_dict(
6103-
json.loads(config_table.config),
6104-
pipeline_configuration=pipeline_configuration,
6105-
)
6106-
for config_table in snapshot.step_configurations
6107-
}
6106+
if snapshot.is_dynamic:
6107+
# Ignore static steps for dynamic pipeline DAGs
6108+
steps = {
6109+
name: Step.from_dict(
6110+
json.loads(step_run.dynamic_config.config),
6111+
pipeline_configuration=pipeline_configuration,
6112+
)
6113+
for name, step_run in step_runs.items()
6114+
}
6115+
else:
6116+
steps = {
6117+
config_table.name: Step.from_dict(
6118+
json.loads(config_table.config),
6119+
pipeline_configuration=pipeline_configuration,
6120+
)
6121+
for config_table in snapshot.step_configurations
6122+
}
61086123
regular_output_artifact_nodes: Dict[
61096124
str, Dict[str, PipelineRunDAG.Node]
61106125
] = defaultdict(dict)

0 commit comments

Comments
 (0)