Skip to content

Commit ab3c031

Browse files
authored
Compute cascading tags server-side (#3781)
* Compute cascading tags server-side
* Linting
* More linting
1 parent 331eaef commit ab3c031

File tree

4 files changed: +64 / -44 lines changed

src/zenml/orchestrators/step_launcher.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -292,10 +292,6 @@ def _bypass() -> None:
292292
artifacts=step_run.outputs,
293293
model_version=model_version,
294294
)
295-
step_run_utils.cascade_tags_for_output_artifacts(
296-
artifacts=step_run.outputs,
297-
tags=pipeline_run.config.tags,
298-
)
299295

300296
except: # noqa: E722
301297
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")

src/zenml/orchestrators/step_run_utils.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
"""Utilities for creating step runs."""
1515

1616
import json
17-
from typing import Dict, List, Optional, Set, Tuple, Union
17+
from typing import Dict, List, Optional, Set, Tuple
1818

19-
from zenml import Tag, add_tags
2019
from zenml.client import Client
2120
from zenml.config.step_configurations import Step
2221
from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH
@@ -404,29 +403,6 @@ def link_output_artifacts_to_model_version(
404403
)
405404

406405

407-
def cascade_tags_for_output_artifacts(
408-
artifacts: Dict[str, List[ArtifactVersionResponse]],
409-
tags: Optional[List[Union[str, Tag]]] = None,
410-
) -> None:
411-
"""Tag the outputs of a step run.
412-
413-
Args:
414-
artifacts: The step output artifacts.
415-
tags: The tags to add to the artifacts.
416-
"""
417-
if tags is None:
418-
return
419-
420-
cascade_tags = [t for t in tags if isinstance(t, Tag) and t.cascade]
421-
422-
for output_artifacts in artifacts.values():
423-
for output_artifact in output_artifacts:
424-
add_tags(
425-
tags=[t.name for t in cascade_tags],
426-
artifact_version_id=output_artifact.id,
427-
)
428-
429-
430406
def publish_cached_step_run(
431407
request: "StepRunRequest", pipeline_run: "PipelineRunResponse"
432408
) -> "StepRunResponse":
@@ -447,11 +423,6 @@ def publish_cached_step_run(
447423
model_version=model_version,
448424
)
449425

450-
cascade_tags_for_output_artifacts(
451-
artifacts=step_run.outputs,
452-
tags=pipeline_run.config.tags,
453-
)
454-
455426
return step_run
456427

457428

src/zenml/zen_stores/schemas/pipeline_run_schemas.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,29 @@ def from_request(
334334
trigger_execution_id=request.trigger_execution_id,
335335
)
336336

337+
def get_pipeline_configuration(self) -> PipelineConfiguration:
338+
"""Get the pipeline configuration for the pipeline run.
339+
340+
Raises:
341+
RuntimeError: if the pipeline run has no deployment and no pipeline
342+
configuration.
343+
344+
Returns:
345+
The pipeline configuration.
346+
"""
347+
if self.deployment:
348+
return PipelineConfiguration.model_validate_json(
349+
self.deployment.pipeline_configuration
350+
)
351+
elif self.pipeline_configuration:
352+
return PipelineConfiguration.model_validate_json(
353+
self.pipeline_configuration
354+
)
355+
else:
356+
raise RuntimeError(
357+
"Pipeline run has no deployment and no pipeline configuration."
358+
)
359+
337360
def fetch_metadata_collection(
338361
self, include_full_metadata: bool = False, **kwargs: Any
339362
) -> Dict[str, List[RunMetadataEntry]]:

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8835,7 +8835,6 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
88358835
reference_id=step_run.pipeline_run_id,
88368836
session=session,
88378837
)
8838-
88398838
self._get_reference_schema_by_id(
88408839
resource=step_run,
88418840
reference_schema=StepRunSchema,
@@ -8911,15 +8910,41 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
89118910
)
89128911
for link in original_metadata_links
89138912
]
8914-
# Add all new links in a single operation
8915-
session.add_all(new_links)
8916-
# Commit the changes
8917-
session.commit()
8918-
session.refresh(step_schema)
89198913

8920-
session.commit()
8921-
session.refresh(step_schema)
8914+
if new_links:
8915+
session.add_all(new_links)
8916+
session.commit()
8917+
session.refresh(step_schema, ["run_metadata"])
89228918

8919+
if step_run.status == ExecutionStatus.CACHED:
8920+
from zenml.utils.tag_utils import Tag
8921+
8922+
cascading_tags = [
8923+
tag
8924+
for tag in run.get_pipeline_configuration().tags or []
8925+
if isinstance(tag, Tag) and tag.cascade
8926+
]
8927+
8928+
if cascading_tags:
8929+
output_artifact_ids = [
8930+
id for ids in step_run.outputs.values() for id in ids
8931+
]
8932+
output_artifacts = list(
8933+
session.exec(
8934+
select(ArtifactVersionSchema).where(
8935+
col(ArtifactVersionSchema.id).in_(
8936+
output_artifact_ids
8937+
)
8938+
)
8939+
).all()
8940+
)
8941+
self._attach_tags_to_resources(
8942+
cascading_tags,
8943+
resources=output_artifacts,
8944+
session=session,
8945+
)
8946+
8947+
session.commit()
89238948
step_model = step_schema.to_model(include_metadata=True)
89248949

89258950
for upstream_step in step_model.spec.upstream_steps:
@@ -8977,7 +9002,9 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
89779002
)
89789003

89799004
session.commit()
8980-
session.refresh(step_schema)
9005+
session.refresh(
9006+
step_schema, ["input_artifacts", "output_artifacts"]
9007+
)
89819008

89829009
if model_version_id := self._get_or_create_model_version_for_run(
89839010
step_schema
@@ -9330,6 +9357,9 @@ def _update_pipeline_run_status(
93309357
"""
93319358
from zenml.orchestrators.publish_utils import get_pipeline_run_status
93329359

9360+
# Make sure we start with a fresh transaction before locking the
9361+
# pipeline run
9362+
session.commit()
93339363
pipeline_run = session.exec(
93349364
select(PipelineRunSchema)
93359365
.with_for_update()
@@ -12253,7 +12283,7 @@ def _get_tag_schema(
1225312283
def _attach_tags_to_resources(
1225412284
self,
1225512285
tags: Optional[Sequence[Union[str, tag_utils.Tag]]],
12256-
resources: Union[BaseSchema, List[BaseSchema]],
12286+
resources: Union[BaseSchema, Sequence[BaseSchema]],
1225712287
session: Session,
1225812288
) -> None:
1225912289
"""Attaches multiple tags to multiple resources.

0 commit comments

Comments
 (0)