Skip to content

Commit aba28cf

Browse files
committed
Enhance artifact storage handling for in-memory materializers
This commit introduces logic to handle in-memory materializers more effectively by avoiding unnecessary interactions with the artifact store. When using an in-memory materializer, the URI is now prefixed with "memory://" to clearly indicate its storage type. Additionally, the artifact store's `makedirs` method is only called when not using in-memory storage, improving performance and clarity. Furthermore, the parameter model construction in the deployment service has been refined for better readability and maintainability. No functional changes were made to the application code outside of these improvements.
1 parent a661699 commit aba28cf

File tree

7 files changed

+100
-41
lines changed

7 files changed

+100
-41
lines changed

src/zenml/artifacts/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,21 @@ def _store_artifact_data_and_prepare_request(
152152
Artifact version request for the artifact data that was stored.
153153
"""
154154
artifact_store = Client().active_stack.artifact_store
155-
artifact_store.makedirs(uri)
155+
156+
# Detect in-memory materializer to avoid touching the artifact store.
157+
# Local import to minimize import-time dependencies.
158+
from zenml.materializers.in_memory_materializer import (
159+
InMemoryMaterializer,
160+
)
161+
162+
is_in_memory = issubclass(materializer_class, InMemoryMaterializer)
163+
164+
if not is_in_memory:
165+
artifact_store.makedirs(uri)
166+
else:
167+
# Ensure URI clearly indicates in-memory storage and not the artifact store
168+
if not uri.startswith("memory://"):
169+
uri = f"memory://custom_artifacts/{name}/{uuid4()}"
156170

157171
materializer = materializer_class(uri=uri, artifact_store=artifact_store)
158172
materializer.uri = materializer.uri.replace("\\", "/")
@@ -190,7 +204,7 @@ def _store_artifact_data_and_prepare_request(
190204
data_type=source_utils.resolve(data_type),
191205
content_hash=content_hash,
192206
project=Client().active_project.id,
193-
artifact_store_id=artifact_store.id,
207+
artifact_store_id=None if is_in_memory else artifact_store.id,
194208
visualizations=visualizations,
195209
has_custom_name=has_custom_name,
196210
save_type=save_type,

src/zenml/deployers/server/parameters.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@ def build_params_model_from_snapshot(
3737
3838
Args:
3939
snapshot: The snapshot to derive the model from.
40-
strict: Whether to raise an error if the model cannot be constructed.
4140
4241
Returns:
4342
A Pydantic `BaseModel` subclass that validates the pipeline parameters,
44-
or None if the model could not be constructed.
43+
or None if the snapshot lacks a valid `pipeline_spec.source`.
4544
4645
Raises:
47-
RuntimeError: If the model cannot be constructed and `strict` is True.
48-
Exception: If loading the pipeline class fails when `strict` is True.
46+
RuntimeError: If the pipeline class cannot be loaded or if no
47+
parameters model can be constructed for the pipeline.
4948
"""
5049
if not snapshot.pipeline_spec or not snapshot.pipeline_spec.source:
5150
msg = (
@@ -66,13 +65,7 @@ def build_params_model_from_snapshot(
6665

6766
model = pipeline_class.get_parameters_model()
6867
if not model:
69-
message = (
70-
f"Failed to construct parameters model from pipeline "
71-
f"`{snapshot.pipeline_configuration.name}`."
68+
raise RuntimeError(
69+
f"Failed to construct parameters model from pipeline `{snapshot.pipeline_configuration.name}`."
7270
)
73-
logger.error(message)
74-
raise RuntimeError(message)
75-
else:
76-
logger.debug(message)
77-
7871
return model

src/zenml/deployers/server/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,4 @@ def get_in_memory_data(uri: str) -> Any:
167167
if is_active():
168168
state = _get_context()
169169
return state.in_memory_data.get(uri)
170-
return None
170+
return None

src/zenml/deployers/server/service.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import time
1818
import traceback
1919
from datetime import datetime, timezone
20-
from typing import Any, Dict, Optional, Type, Union
20+
from typing import Any, Dict, Optional, Tuple, Type, Union
2121
from uuid import UUID, uuid4
2222

2323
from pydantic import BaseModel
@@ -136,7 +136,9 @@ def initialize(self) -> None:
136136
integration_registry.activate_integrations()
137137

138138
# Build parameter model
139-
self._params_model = build_params_model_from_snapshot(self.snapshot, strict=True)
139+
self._params_model = build_params_model_from_snapshot(
140+
snapshot=self.snapshot,
141+
)
140142

141143
# Initialize orchestrator
142144
self._orchestrator = SharedLocalOrchestrator(
@@ -208,11 +210,12 @@ def execute_pipeline(
208210

209211
placeholder_run: Optional[PipelineRunResponse] = None
210212
try:
211-
placeholder_run = self._prepare_execute_with_orchestrator()
212-
213213
# Execute pipeline and get runtime outputs captured internally
214-
captured_outputs = self._execute_with_orchestrator(
215-
placeholder_run, parameters, request.use_in_memory
214+
placeholder_run, captured_outputs = (
215+
self._execute_with_orchestrator(
216+
resolved_params=parameters,
217+
use_in_memory=request.use_in_memory,
218+
)
216219
)
217220

218221
# Map outputs using fast (in-memory) or slow (artifact) path
@@ -327,19 +330,17 @@ def _map_outputs(
327330

328331
def _execute_with_orchestrator(
329332
self,
330-
placeholder_run: PipelineRunResponse,
331333
resolved_params: Dict[str, Any],
332334
use_in_memory: bool,
333-
) -> Optional[Dict[str, Dict[str, Any]]]:
335+
) -> Tuple[PipelineRunResponse, Optional[Dict[str, Dict[str, Any]]]]:
334336
"""Run the snapshot via the orchestrator and return the concrete run.
335337
336338
Args:
337-
placeholder_run: The placeholder run to execute the pipeline on.
338339
resolved_params: Normalized pipeline parameters.
339340
use_in_memory: Whether runtime should capture in-memory outputs.
340341
341342
Returns:
342-
The in-memory outputs of the pipeline execution.
343+
A tuple of (placeholder_run, in-memory outputs of the execution).
343344
344345
Raises:
345346
RuntimeError: If the orchestrator has not been initialized.
@@ -400,9 +401,7 @@ def _execute_with_orchestrator(
400401
finally:
401402
# Always stop deployment runtime context
402403
runtime.stop()
403-
404-
# Store captured outputs for the caller to use
405-
return captured_outputs
404+
return placeholder_run, captured_outputs
406405

407406
def _execute_init_hook(self) -> None:
408407
"""Execute init hook if present.

src/zenml/orchestrators/output_utils.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,19 @@ def generate_artifact_uri(
5858

5959

6060
def prepare_output_artifact_uris(
61-
step_run: "StepRunResponse", stack: "Stack", step: "Step"
61+
step_run: "StepRunResponse",
62+
stack: "Stack",
63+
step: "Step",
64+
*,
65+
create_dirs: bool = True,
6266
) -> Dict[str, str]:
6367
"""Prepares the output artifact URIs to run the current step.
6468
6569
Args:
6670
step_run: The step run for which to prepare the artifact URIs.
6771
stack: The stack on which the pipeline is running.
6872
step: The step configuration.
73+
create_dirs: Whether to pre-create directories in the artifact store.
6974
7075
Raises:
7176
RuntimeError: If an artifact URI already exists.
@@ -75,18 +80,43 @@ def prepare_output_artifact_uris(
7580
"""
7681
artifact_store = stack.artifact_store
7782
output_artifact_uris: Dict[str, str] = {}
83+
7884
for output_name in step.config.outputs.keys():
7985
substituted_output_name = string_utils.format_name_template(
8086
output_name, substitutions=step_run.config.substitutions
8187
)
82-
artifact_uri = generate_artifact_uri(
83-
artifact_store=stack.artifact_store,
84-
step_run=step_run,
85-
output_name=substituted_output_name,
86-
)
87-
if artifact_store.exists(artifact_uri):
88-
raise RuntimeError("Artifact already exists")
89-
artifact_store.makedirs(artifact_uri)
88+
if create_dirs:
89+
artifact_uri = generate_artifact_uri(
90+
artifact_store=artifact_store,
91+
step_run=step_run,
92+
output_name=substituted_output_name,
93+
)
94+
else:
95+
# Produce a clear in-memory URI that doesn't point to the store.
96+
sanitized_output = substituted_output_name
97+
for banned_character in [
98+
"<",
99+
">",
100+
":",
101+
'"',
102+
"/",
103+
"\\",
104+
"|",
105+
"?",
106+
"*",
107+
]:
108+
sanitized_output = sanitized_output.replace(
109+
banned_character, "_"
110+
)
111+
artifact_uri = (
112+
f"memory://{step_run.name}/{sanitized_output}/"
113+
f"{step_run.id}/{str(uuid4())[:8]}"
114+
)
115+
116+
if create_dirs:
117+
if artifact_store.exists(artifact_uri):
118+
raise RuntimeError("Artifact already exists")
119+
artifact_store.makedirs(artifact_uri)
90120
output_artifact_uris[output_name] = artifact_uri
91121
return output_artifact_uris
92122

src/zenml/orchestrators/step_launcher.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ENV_ZENML_STEP_OPERATOR,
2727
handle_bool_env_var,
2828
)
29+
from zenml.deployers.server import runtime
2930
from zenml.enums import ExecutionMode, ExecutionStatus
3031
from zenml.environment import get_run_environment_dict
3132
from zenml.exceptions import RunInterruptedException, RunStoppedException
@@ -438,7 +439,10 @@ def _run_step(
438439
)
439440

440441
output_artifact_uris = output_utils.prepare_output_artifact_uris(
441-
step_run=step_run, stack=self._stack, step=self._step
442+
step_run=step_run,
443+
stack=self._stack,
444+
step=self._step,
445+
create_dirs=not runtime.should_use_in_memory_mode(),
442446
)
443447

444448
start_time = time.time()

src/zenml/orchestrators/utils.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,27 @@ def deployment_snapshot_request_from_source_snapshot(
434434

435435
if source_snapshot.stack is None:
436436
raise ValueError("Source snapshot stack is None")
437-
if source_snapshot.pipeline is None:
438-
raise ValueError("Source snapshot pipeline is None")
437+
438+
# Update the pipeline spec parameters by overriding only known keys
439+
updated_pipeline_spec = source_snapshot.pipeline_spec
440+
try:
441+
if (
442+
source_snapshot.pipeline_spec
443+
and source_snapshot.pipeline_spec.parameters is not None
444+
):
445+
original_params: Dict[str, Any] = dict(
446+
source_snapshot.pipeline_spec.parameters
447+
)
448+
merged_params: Dict[str, Any] = original_params.copy()
449+
for k, v in deployment_parameters.items():
450+
if k in original_params:
451+
merged_params[k] = v
452+
updated_pipeline_spec = pydantic_utils.update_model(
453+
source_snapshot.pipeline_spec, {"parameters": merged_params}
454+
)
455+
except Exception:
456+
# In case of any unforeseen errors, fall back to the original spec
457+
updated_pipeline_spec = source_snapshot.pipeline_spec
439458

440459
return PipelineSnapshotRequest(
441460
project=source_snapshot.project_id,
@@ -454,5 +473,5 @@ def deployment_snapshot_request_from_source_snapshot(
454473
template=template_id,
455474
source_snapshot=source_snapshot_id,
456475
pipeline_version_hash=source_snapshot.pipeline_version_hash,
457-
pipeline_spec=source_snapshot.pipeline_spec,
476+
pipeline_spec=updated_pipeline_spec,
458477
)

0 commit comments

Comments
 (0)