Skip to content

Commit 97a2db0

Browse files
Feature:4015 Bulk log metadata functionality
1 parent d48985a commit 97a2db0

File tree

7 files changed

+613
-2
lines changed

7 files changed

+613
-2
lines changed

src/zenml/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,11 @@
320320
StackResponseMetadata,
321321
StackResponseResources
322322
)
323+
from zenml.models.v2.misc.param_groups import (
324+
PipelineRunIdentifier,
325+
StepRunIdentifier,
326+
VersionedIdentifier,
327+
)
323328
from zenml.models.v2.misc.statistics import (
324329
ProjectStatistics,
325330
ServerStatistics,
@@ -874,4 +879,7 @@
874879
"ProjectStatistics",
875880
"PipelineRunDAG",
876881
"ExceptionInfo",
882+
"VersionedIdentifier",
883+
"PipelineRunIdentifier",
884+
"StepRunIdentifier",
877885
]
src/zenml/models/v2/misc/param_groups.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Parameter group classes."""
15+
16+
from uuid import UUID
17+
18+
from pydantic import BaseModel, model_validator
19+
20+
21+
class VersionedIdentifier(BaseModel):
    """Identifier for an entity resolved by UUID or by name and version.

    Exactly one identification option is accepted: either `id` on its own,
    or `name` together with `version`. Any other combination is rejected
    by the validator when the model is constructed.
    """

    # UUID of the entity; mutually exclusive with `name`.
    id: UUID | None = None
    # Name of the entity; must be accompanied by `version`.
    name: str | None = None
    # Version of the entity; only valid together with `name`.
    version: str | None = None

    @model_validator(mode="after")
    def _validate_options(self) -> "VersionedIdentifier":
        """Check that exactly one identification option is used.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If both or neither of `id` and `name` are set, or
                if only one of `name` and `version` is set.
        """
        if self.id and self.name:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name."
            )

        if not (self.id or self.name):
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name."
            )

        # XOR: a name without a version (or a version without a name, i.e.
        # a version combined with an id) is incomplete.
        if bool(self.name) ^ bool(self.version):
            raise ValueError("You need to specify both name and version.")

        return self
46+
47+
48+
class PipelineRunIdentifier(BaseModel):
    """Identifier for a pipeline run given by ID, name or prefix.

    Exactly one of `id`, `name` and `prefix` must be set; this is enforced
    by the validator when the model is constructed.
    """

    # UUID of the pipeline run.
    id: UUID | None = None
    # Name of the pipeline run.
    name: str | None = None
    # Prefix used to look up the pipeline run.
    prefix: str | None = None

    @property
    def value(self) -> str | UUID:
        """Resolves the set value out of id, name and prefix.

        Returns:
            The id/name/prefix (the first one that is set, in this exact
            order).
        """
        return self.id or self.name or self.prefix  # type: ignore[return-value]

    @model_validator(mode="after")
    def _validate_options(self) -> "PipelineRunIdentifier":
        """Check that exactly one identification option is used.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If more than one or none of `id`, `name` and
                `prefix` are set.
        """
        options_set = sum(
            bool(option) for option in (self.id, self.name, self.prefix)
        )

        if options_set > 1:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name or prefix."
            )

        if options_set == 0:
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name or prefix."
            )

        return self
85+
86+
87+
class StepRunIdentifier(BaseModel):
    """Identifier for a step run given by ID or by name within a pipeline run.

    Exactly one identification option is accepted: either `id` on its own,
    or `name` together with a `pipeline` run identifier. Any other
    combination is rejected by the validator when the model is constructed.
    """

    # UUID of the step run; mutually exclusive with `name`.
    id: UUID | None = None
    # Name of the step; must be accompanied by `pipeline`.
    name: str | None = None
    # Pipeline run that the named step belongs to; only valid with `name`.
    pipeline: PipelineRunIdentifier | None = None

    @model_validator(mode="after")
    def _validate_options(self) -> "StepRunIdentifier":
        """Check that exactly one identification option is used.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If both or neither of `id` and `name` are set, if
                `name` is set without `pipeline`, or if `pipeline` is set
                without `name`.
        """
        if self.id and self.name:
            raise ValueError(
                "You can use only one identification option at a time. "
                "Use either id or name."
            )

        if not (self.id or self.name):
            raise ValueError(
                "You have to use at least one identification option. "
                "Use either id or name."
            )

        if self.name and not self.pipeline:
            raise ValueError(
                "To identify a run by name you need to specify a pipeline run identifier."
            )

        # At this point `name` is unset whenever `pipeline` lacks a name,
        # which means the step is identified by id and the pipeline run
        # identifier is superfluous.
        if self.pipeline and not self.name:
            raise ValueError(
                "A pipeline run identifier can only be used together with "
                "a step name, not with an id."
            )

        return self

src/zenml/models/v2/misc/run_metadata.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Utility classes for modeling run metadata."""
1515

1616
from datetime import datetime
17+
from typing import Any
1718
from uuid import UUID
1819

1920
from pydantic import BaseModel, Field
@@ -28,6 +29,31 @@ class RunMetadataResource(BaseModel):
2829
id: UUID = Field(title="The ID of the resource.")
2930
type: MetadataResourceTypes = Field(title="The type of the resource.")
3031

32+
def __eq__(self, other: Any) -> bool:
    """Overrides equality operator.

    Two resources are equal when both their ID and their type match,
    which are the same fields that `__hash__` is derived from.

    Args:
        other: The object to compare.

    Returns:
        True if the other object is a RunMetadataResource with the same
        ID and type, False if it is one with a different ID or type.
        For any other kind of object `NotImplemented` is returned so
        that Python falls back to its default handling (the objects
        compare as unequal) instead of raising.
    """
    if not isinstance(other, RunMetadataResource):
        return NotImplemented

    # Compare the fields directly rather than the hashes, so that a hash
    # collision can never make two different resources compare equal.
    return self.id == other.id and self.type == other.type
48+
49+
def __hash__(self) -> int:
    """Computes the hash of the resource from its ID and type.

    Returns:
        The hash of the string built from the resource ID and the value
        of the resource type, joined by an underscore.
    """
    hash_key = f"{self.id}_{self.type.value}"
    return hash(hash_key)
56+
3157

3258
class RunMetadataEntry(BaseModel):
3359
"""Utility class to sort/list run metadata entries."""

src/zenml/utils/metadata_utils.py

Lines changed: 165 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,19 @@
1313
# permissions and limitations under the License.
1414
"""Utility functions to handle metadata for ZenML entities."""
1515

16-
from typing import Dict, List, Optional, Union, overload
16+
from typing import Dict, List, Optional, Set, Union, overload
1717
from uuid import UUID
1818

1919
from zenml.client import Client
2020
from zenml.enums import MetadataResourceTypes, ModelStages
2121
from zenml.logger import get_logger
2222
from zenml.metadata.metadata_types import MetadataType
23-
from zenml.models import RunMetadataResource
23+
from zenml.models import (
24+
PipelineRunIdentifier,
25+
RunMetadataResource,
26+
StepRunIdentifier,
27+
VersionedIdentifier,
28+
)
2429
from zenml.steps.step_context import get_step_context
2530

2631
logger = get_logger(__name__)
@@ -366,3 +371,161 @@ def log_metadata(
366371
resources=resources,
367372
publisher_step_id=publisher_step_id,
368373
)
374+
375+
376+
def bulk_log_metadata(
    metadata: Dict[str, MetadataType],
    pipeline_runs: list[PipelineRunIdentifier] | None = None,
    step_runs: list[StepRunIdentifier] | None = None,
    artifact_versions: list[VersionedIdentifier] | None = None,
    model_versions: list[VersionedIdentifier] | None = None,
    infer_models: bool = False,
    infer_artifacts: bool = False,
) -> None:
    """Logs metadata for multiple entities in a single invocation.

    Identifiers that do not carry an ID are resolved through the client.
    All explicitly listed entities, plus the model version inferred from
    the step context, are collected into one set of resources and logged
    with a single `create_run_metadata` call. Metadata for inferred
    artifacts is attached to every output of the current step through the
    step context instead. The identifier objects passed in are not
    modified.

    Args:
        metadata: The metadata to log.
        pipeline_runs: A list of pipeline runs to log metadata for.
        step_runs: A list of step runs to log metadata for.
        artifact_versions: A list of artifact versions to log metadata for.
        model_versions: A list of model versions to log metadata for.
        infer_models: Flag - when enabled infer model to log metadata for
            from step context.
        infer_artifacts: Flag - when enabled infer artifact to log metadata
            for from step context.

    Raises:
        ValueError: If no metadata or no target is given, if options are
            not passed correctly (infer options with explicit
            declarations), if invocation with `infer` options is done
            outside of a step context, or if `infer_models` is used and
            the step context has no model version.
    """
    if not metadata:
        raise ValueError("You must provide metadata to log.")

    targets = (
        pipeline_runs,
        step_runs,
        artifact_versions,
        model_versions,
        infer_models,
        infer_artifacts,
    )
    if not any(targets):
        raise ValueError(
            "You must select at least one pipeline/step/artifact/model to log metadata to."
        )

    if infer_models and model_versions:
        raise ValueError(
            "You can either specify model versions or use the infer option."
        )

    if infer_artifacts and artifact_versions:
        raise ValueError(
            "You can either specify artifact versions or use the infer option."
        )

    try:
        step_context = get_step_context()
    except RuntimeError:
        # Not running inside a step; only a problem for the infer options.
        step_context = None

    if (infer_models or infer_artifacts) and step_context is None:
        raise ValueError(
            "Infer options can be used only within a step function code."
        )

    # The client is only created once the arguments are known to be valid.
    client = Client()

    # A set de-duplicates entities that were listed more than once.
    resources: Set[RunMetadataResource] = set()

    # resolve pipeline runs and add metadata resources

    for pipeline_run in pipeline_runs or []:
        pipeline_run_id = (
            pipeline_run.id
            or client.get_pipeline_run(
                name_id_or_prefix=pipeline_run.value
            ).id
        )
        resources.add(
            RunMetadataResource(
                id=pipeline_run_id, type=MetadataResourceTypes.PIPELINE_RUN
            )
        )

    # resolve step runs and add metadata resources

    for step_run in step_runs or []:
        step_run_id = step_run.id
        if not step_run_id:
            # Steps without an ID are looked up by name in their
            # pipeline run.
            parent_run = client.get_pipeline_run(
                name_id_or_prefix=step_run.pipeline.value
            )
            step_run_id = parent_run.steps[step_run.name].id

        resources.add(
            RunMetadataResource(
                id=step_run_id, type=MetadataResourceTypes.STEP_RUN
            )
        )

    # resolve artifacts and add metadata resources

    for artifact_version in artifact_versions or []:
        artifact_version_id = (
            artifact_version.id
            or client.get_artifact_version(
                name_id_or_prefix=artifact_version.name,
                version=artifact_version.version,
            ).id
        )
        resources.add(
            RunMetadataResource(
                id=artifact_version_id,
                type=MetadataResourceTypes.ARTIFACT_VERSION,
            )
        )

    # resolve models and add metadata resources

    for model_version in model_versions or []:
        model_version_id = (
            model_version.id
            or client.get_model_version(
                model_name_or_id=model_version.name,
                model_version_name_or_number_or_id=model_version.version,
            ).id
        )
        resources.add(
            RunMetadataResource(
                id=model_version_id, type=MetadataResourceTypes.MODEL_VERSION
            )
        )

    # infer models - resolve from step context

    if infer_models:
        inferred_model_version = step_context.model_version
        if not inferred_model_version:
            raise ValueError(
                "The step context does not feature any model versions."
            )
        resources.add(
            RunMetadataResource(
                id=inferred_model_version.id,
                type=MetadataResourceTypes.MODEL_VERSION,
            )
        )

    # infer artifacts - resolve from step context

    if infer_artifacts:
        # Output metadata is attached via the step context, so the
        # outputs are not added to `resources`.
        step_output_names = list(step_context._outputs.keys())

        for artifact_name in step_output_names:
            step_context.add_output_metadata(
                metadata=metadata, output_name=artifact_name
            )

    if not resources:
        return

    client.create_run_metadata(
        metadata=metadata,
        resources=list(resources),
        publisher_step_id=None,
    )

0 commit comments

Comments
 (0)