Skip to content

Commit 87e2a54

Browse files
committed
Artifact cache
1 parent f1ed44a commit 87e2a54

File tree

4 files changed

+121
-44
lines changed

4 files changed

+121
-44
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""In-memory artifact cache."""
15+
16+
from typing import Any
17+
from uuid import UUID
18+
19+
from zenml.utils import context_utils
20+
21+
22+
class InMemoryArtifactCache(context_utils.BaseContext):
23+
"""In-memory artifact cache."""
24+
25+
__context_var__ = context_utils.ContextVar("in_memory_artifact_cache")
26+
27+
def __init__(self) -> None:
28+
"""Initialize the artifact cache."""
29+
super().__init__()
30+
self._cache = {}
31+
32+
def clear(self) -> None:
33+
"""Clear the artifact cache."""
34+
self._cache = {}
35+
36+
def get_artifact_data(self, id_: UUID) -> Any:
37+
"""Get the artifact data.
38+
39+
Args:
40+
id_: The ID of the artifact to get the data for.
41+
42+
Returns:
43+
The artifact data.
44+
"""
45+
return self._cache.get(id_)
46+
47+
def set_artifact_data(self, id_: UUID, data: Any) -> None:
48+
"""Set the artifact data.
49+
50+
Args:
51+
id_: The ID of the artifact to set the data for.
52+
data: The artifact data to set.
53+
"""
54+
self._cache[id_] = data

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,16 @@ def result(self) -> OutputArtifact:
103103
f"{result}."
104104
)
105105

106-
def load(self) -> Any:
106+
def load(self, disable_cache: bool = False) -> Any:
107107
"""Load the step run output artifact data.
108108
109+
Args:
110+
disable_cache: Whether to disable the artifact cache.
111+
109112
Returns:
110113
The step run output artifact data.
111114
"""
112-
return self.result().load()
115+
return self.result().load(disable_cache=disable_cache)
113116

114117

115118
class StepRunOutputsFuture(_BaseStepRunFuture):
@@ -157,9 +160,12 @@ def artifacts(self) -> StepRunOutputs:
157160
"""
158161
return self._wrapped.result()
159162

160-
def load(self) -> Any:
163+
def load(self, disable_cache: bool = False) -> Any:
161164
"""Get the step run output artifact data.
162165
166+
Args:
167+
disable_cache: Whether to disable the artifact cache.
168+
163169
Raises:
164170
ValueError: If the step run output is invalid.
165171
@@ -171,9 +177,11 @@ def load(self) -> Any:
171177
if result is None:
172178
return None
173179
elif isinstance(result, ArtifactVersionResponse):
174-
return result.load()
180+
return result.load(disable_cache=disable_cache)
175181
elif isinstance(result, tuple):
176-
return tuple(item.load() for item in result)
182+
return tuple(
183+
item.load(disable_cache=disable_cache) for item in result
184+
)
177185
else:
178186
raise ValueError(f"Invalid step run output: {result}")
179187

src/zenml/execution/pipeline/dynamic/runner.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from uuid import UUID
3232

3333
from zenml import ExternalArtifact
34+
from zenml.artifacts.in_memory_cache import InMemoryArtifactCache
3435
from zenml.client import Client
3536
from zenml.config.compiler import Compiler
3637
from zenml.config.step_configurations import Step
@@ -150,44 +151,45 @@ def run_pipeline(self) -> None:
150151
snapshot=self._snapshot,
151152
run_id=self._run.id if self._run else None,
152153
) as logs_request:
153-
run = self._run or create_placeholder_run(
154-
snapshot=self._snapshot,
155-
orchestrator_run_id=self._orchestrator_run_id,
156-
logs=logs_request,
157-
)
158-
159-
assert (
160-
self._snapshot.pipeline_spec
161-
) # Always exists for new snapshots
162-
pipeline_parameters = self._snapshot.pipeline_spec.parameters
154+
with InMemoryArtifactCache():
155+
run = self._run or create_placeholder_run(
156+
snapshot=self._snapshot,
157+
orchestrator_run_id=self._orchestrator_run_id,
158+
logs=logs_request,
159+
)
163160

164-
with DynamicPipelineRunContext(
165-
pipeline=self.pipeline,
166-
run=run,
167-
snapshot=self._snapshot,
168-
runner=self,
169-
):
170-
self._orchestrator.run_init_hook(snapshot=self._snapshot)
171-
try:
172-
# TODO: step logging isn't threadsafe
173-
# TODO: what should be allowed as pipeline returns?
174-
# (artifacts, json serializable, anything?)
175-
# how do we show it in the UI?
176-
self.pipeline._call_entrypoint(**pipeline_parameters)
177-
except:
178-
publish_failed_pipeline_run(run.id)
179-
logger.error(
180-
"Pipeline run failed. All in-progress step runs will "
181-
"still finish executing."
182-
)
183-
raise
184-
finally:
185-
self._orchestrator.run_cleanup_hook(
186-
snapshot=self._snapshot
187-
)
188-
self._executor.shutdown(wait=True, cancel_futures=True)
189-
# self.await_all_step_run_futures()
190-
publish_successful_pipeline_run(run.id)
161+
assert (
162+
self._snapshot.pipeline_spec
163+
) # Always exists for new snapshots
164+
pipeline_parameters = self._snapshot.pipeline_spec.parameters
165+
166+
with DynamicPipelineRunContext(
167+
pipeline=self.pipeline,
168+
run=run,
169+
snapshot=self._snapshot,
170+
runner=self,
171+
):
172+
self._orchestrator.run_init_hook(snapshot=self._snapshot)
173+
try:
174+
# TODO: step logging isn't threadsafe
175+
# TODO: what should be allowed as pipeline returns?
176+
# (artifacts, json serializable, anything?)
177+
# how do we show it in the UI?
178+
self.pipeline._call_entrypoint(**pipeline_parameters)
179+
except:
180+
publish_failed_pipeline_run(run.id)
181+
logger.error(
182+
"Pipeline run failed. All in-progress step runs "
183+
"will still finish executing."
184+
)
185+
raise
186+
finally:
187+
self._orchestrator.run_cleanup_hook(
188+
snapshot=self._snapshot
189+
)
190+
self._executor.shutdown(wait=True, cancel_futures=True)
191+
# self.await_all_step_run_futures()
192+
publish_successful_pipeline_run(run.id)
191193

192194
@overload
193195
def launch_step(

src/zenml/models/v2/core/artifact_version.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,15 +440,28 @@ def run(self) -> "PipelineRunResponse":
440440

441441
return Client().get_pipeline_run(self.step.pipeline_run_id)
442442

443-
def load(self) -> Any:
443+
def load(self, disable_cache: bool = False) -> Any:
444444
"""Materializes (loads) the data stored in this artifact.
445445
446+
Args:
447+
disable_cache: Whether to disable the artifact cache.
448+
446449
Returns:
447450
The materialized data.
448451
"""
452+
from zenml.artifacts.in_memory_cache import InMemoryArtifactCache
449453
from zenml.artifacts.utils import load_artifact_from_response
450454

451-
return load_artifact_from_response(self)
455+
cache = InMemoryArtifactCache.get()
456+
457+
if cache and (data := cache.get_artifact_data(self.id)):
458+
logger.debug(f"Returning artifact data (%s) from cache", self.id)
459+
return data
460+
461+
data = load_artifact_from_response(self)
462+
if not disable_cache:
463+
cache.set_artifact_data(self.id, data)
464+
return data
452465

453466
def download_files(self, path: str, overwrite: bool = False) -> None:
454467
"""Downloads data for an artifact with no materializing.

0 commit comments

Comments
 (0)