Skip to content

Commit 98d115e

Browse files
committed
Dynamic step config
1 parent c6153aa commit 98d115e

File tree

8 files changed: 91 additions and 78 deletions.

src/zenml/models/v2/core/pipeline_snapshot.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,6 @@ class PipelineSnapshotUpdate(BaseUpdate):
208208
remove_tags: Optional[List[str]] = Field(
209209
default=None, title="Tags to remove from the snapshot."
210210
)
211-
add_steps: Optional[Dict[str, Step]] = None
212211

213212
@field_validator("name")
214213
@classmethod

src/zenml/models/v2/core/step_run.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828

2929
from pydantic import ConfigDict, Field
3030

31-
from zenml.config.step_configurations import StepConfiguration, StepSpec
31+
from zenml.config.step_configurations import Step, StepConfiguration, StepSpec
3232
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
3333
from zenml.enums import (
3434
ArtifactSaveType,
@@ -148,6 +148,10 @@ class StepRunRequest(ProjectScopedRequest):
148148
default=None,
149149
title="The exception information of the step run.",
150150
)
151+
dynamic_config: Optional["Step"] = Field(
152+
title="The dynamic configuration of the step run.",
153+
default=None,
154+
)
151155

152156
model_config = ConfigDict(protected_namespaces=())
153157

src/zenml/orchestrators/step_launcher.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,7 @@ def __init__(
108108
snapshot: PipelineSnapshotResponse,
109109
step: Step,
110110
orchestrator_run_id: str,
111+
dynamic: bool = False,
111112
):
112113
"""Initializes the launcher.
113114
@@ -122,6 +123,7 @@ def __init__(
122123
self._snapshot = snapshot
123124
self._step = step
124125
self._orchestrator_run_id = orchestrator_run_id
126+
self._dynamic = dynamic
125127

126128
if not snapshot.stack:
127129
raise RuntimeError(
@@ -306,7 +308,8 @@ def launch(self) -> StepRunResponse:
306308
stack=self._stack,
307309
)
308310
step_run_request = request_factory.create_request(
309-
invocation_id=self._invocation_id
311+
invocation_id=self._invocation_id,
312+
dynamic_config=self._step if self._dynamic else None,
310313
)
311314
step_run_request.logs = logs_model
312315

src/zenml/orchestrators/step_run_utils.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -75,7 +75,9 @@ def has_caching_enabled(self, invocation_id: str) -> bool:
7575
is_enabled_on_pipeline=self.snapshot.pipeline_configuration.enable_cache,
7676
)
7777

78-
def create_request(self, invocation_id: str) -> StepRunRequest:
78+
def create_request(
79+
self, invocation_id: str, dynamic_config: Optional[Step] = None
80+
) -> StepRunRequest:
7981
"""Create a step run request.
8082
8183
This will only create a request with basic information and will not yet
@@ -95,6 +97,7 @@ def create_request(self, invocation_id: str) -> StepRunRequest:
9597
status=ExecutionStatus.RUNNING,
9698
start_time=utc_now(),
9799
project=Client().active_project.id,
100+
dynamic_config=dynamic_config,
98101
)
99102

100103
def populate_request(
@@ -110,7 +113,10 @@ def populate_request(
110113
input resolution. This will be updated in-place with newly
111114
fetched step runs.
112115
"""
113-
step = self.snapshot.step_configurations[request.name]
116+
step = (
117+
request.dynamic_config
118+
or self.snapshot.step_configurations[request.name]
119+
)
114120

115121
input_artifacts = input_utils.resolve_step_inputs(
116122
step=step,
@@ -133,7 +139,9 @@ def populate_request(
133139
(
134140
docstring,
135141
source_code,
136-
) = self._get_docstring_and_source_code(invocation_id=request.name)
142+
) = self._get_docstring_and_source_code(
143+
invocation_id=request.name, step=step
144+
)
137145

138146
request.docstring = docstring
139147
request.source_code = source_code
@@ -174,7 +182,7 @@ def populate_request(
174182
request.docstring = cached_step_run.docstring
175183

176184
def _get_docstring_and_source_code(
177-
self, invocation_id: str
185+
self, invocation_id: str, step: "Step"
178186
) -> Tuple[Optional[str], Optional[str]]:
179187
"""Get the docstring and source code for the step.
180188
@@ -185,8 +193,6 @@ def _get_docstring_and_source_code(
185193
Returns:
186194
The docstring and source code of the step.
187195
"""
188-
step = self.snapshot.step_configurations[invocation_id]
189-
190196
try:
191197
return self._get_docstring_and_source_code_from_step_instance(
192198
step=step

src/zenml/pipelines/dynamic/runner.py

Lines changed: 30 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,11 @@
1313
Tuple,
1414
Union,
1515
)
16-
from zenml.client import Client
17-
from zenml.config.compiler import Compiler
18-
1916
from uuid import UUID
17+
2018
from zenml import ExternalArtifact
2119
from zenml.client import Client
20+
from zenml.config.compiler import Compiler
2221
from zenml.config.step_configurations import Step
2322
from zenml.exceptions import RunStoppedException
2423
from zenml.logger import get_logger
@@ -27,7 +26,6 @@
2726
ArtifactVersionResponse,
2827
PipelineRunResponse,
2928
PipelineSnapshotResponse,
30-
PipelineSnapshotUpdate,
3129
)
3230
from zenml.models.v2.core.step_run import StepRunResponse
3331
from zenml.orchestrators.publish_utils import (
@@ -169,18 +167,19 @@ def run_step_sync(
169167
compiled_step, invocation_id = _compile_step(
170168
self.pipeline, step, id, upstream_steps, inputs
171169
)
172-
updated_snapshot = Client().zen_store.update_snapshot(
173-
self._snapshot.id,
174-
snapshot_update=PipelineSnapshotUpdate(
175-
add_steps={invocation_id: compiled_step}
176-
),
177-
)
178-
step_config = updated_snapshot.step_configurations[invocation_id]
170+
# updated_snapshot = Client().zen_store.update_snapshot(
171+
# self._snapshot.id,
172+
# snapshot_update=PipelineSnapshotUpdate(
173+
# add_steps={invocation_id: compiled_step}
174+
# ),
175+
# )
176+
# step_config = updated_snapshot.step_configurations[invocation_id]
179177
step_run = _run_step_sync(
180-
snapshot=updated_snapshot,
181-
step=step_config,
178+
snapshot=self._snapshot,
179+
step=compiled_step,
182180
orchestrator_run_id=self._run.orchestrator_run_id,
183-
retry=_should_retry_locally(step_config),
181+
retry=_should_retry_locally(compiled_step),
182+
dynamic=True,
184183
)
185184
return _load_step_result(step_run.id)
186185

@@ -198,31 +197,20 @@ def run_step_in_thread(
198197
compiled_step, invocation_id = _compile_step(
199198
self.pipeline, step, id, upstream_steps, inputs
200199
)
201-
updated_snapshot = Client().zen_store.update_snapshot(
202-
self._snapshot.id,
203-
snapshot_update=PipelineSnapshotUpdate(
204-
add_steps={invocation_id: compiled_step}
205-
),
206-
)
207-
step_config = updated_snapshot.step_configurations[invocation_id]
208200

209201
def _run() -> StepRunResult:
210202
step_run = _run_step_sync(
211-
snapshot=updated_snapshot,
212-
step=step_config,
203+
snapshot=self._snapshot,
204+
step=compiled_step,
213205
orchestrator_run_id=self._run.orchestrator_run_id,
214-
retry=_should_retry_locally(step_config),
206+
retry=_should_retry_locally(compiled_step),
207+
dynamic=True,
215208
)
216209
return _load_step_result(step_run.id)
217210

218211
ctx = contextvars.copy_context()
219-
future = self._executor.submit(
220-
ctx.run, _run
221-
)
222-
return StepRunResultFuture(
223-
wrapped=future, invocation_id=invocation_id
224-
)
225-
212+
future = self._executor.submit(ctx.run, _run)
213+
return StepRunResultFuture(wrapped=future, invocation_id=invocation_id)
226214

227215

228216
def _prepare_step_run(
@@ -318,12 +306,14 @@ def _run_step_sync(
318306
step: "Step",
319307
orchestrator_run_id: str,
320308
retry: bool = False,
309+
dynamic: bool = False,
321310
) -> StepRunResponse:
322311
def _launch_step() -> StepRunResponse:
323312
launcher = StepLauncher(
324313
snapshot=snapshot,
325314
step=step,
326315
orchestrator_run_id=orchestrator_run_id,
316+
dynamic=dynamic,
327317
)
328318
return launcher.launch()
329319

@@ -362,7 +352,7 @@ def _launch_step() -> StepRunResponse:
362352
raise
363353
else:
364354
break
365-
355+
366356
return step_run
367357

368358

@@ -377,7 +367,7 @@ def _convert_output_artifact(
377367
step_name=step_run.name,
378368
**artifact.model_dump(),
379369
)
380-
370+
381371
output_artifacts = step_run.regular_outputs
382372
if len(output_artifacts) == 0:
383373
return None
@@ -399,16 +389,19 @@ def _should_retry_locally(step: "Step") -> bool:
399389
return True
400390
else:
401391
# Running out of process with the orchestrator
402-
return not Client().active_stack.orchestrator.config.handles_step_retries
392+
return (
393+
not Client().active_stack.orchestrator.config.handles_step_retries
394+
)
395+
403396

404397
def _runs_in_process(step: "Step") -> bool:
405398
if step.config.step_operator:
406399
return False
407-
400+
408401
if not Client().active_stack.orchestrator.supports_dynamic_out_of_process_steps:
409402
return False
410403

411404
if step.config.in_process is False:
412405
return False
413-
414-
return True
406+
407+
return True

src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -586,6 +586,7 @@ class StepConfigurationSchema(BaseSchema, table=True):
586586
__table_args__ = (
587587
UniqueConstraint(
588588
"snapshot_id",
589+
"step_run_id",
589590
"name",
590591
name="unique_step_name_for_snapshot",
591592
),
@@ -608,5 +609,13 @@ class StepConfigurationSchema(BaseSchema, table=True):
608609
source_column="snapshot_id",
609610
target_column="id",
610611
ondelete="CASCADE",
611-
nullable=False,
612+
nullable=True,
613+
)
614+
step_run_id: UUID = build_foreign_key_field(
615+
source=__tablename__,
616+
target="step_run",
617+
source_column="step_run_id",
618+
target_column="id",
619+
ondelete="CASCADE",
620+
nullable=True,
612621
)

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -216,6 +216,9 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
216216
),
217217
)
218218
)
219+
dynamic_step_configuration_schema: Optional["StepConfigurationSchema"] = (
220+
Relationship()
221+
)
219222

220223
model_config = ConfigDict(protected_namespaces=()) # type: ignore[assignment]
221224

@@ -344,7 +347,10 @@ def get_step_configuration(self) -> Step:
344347
step = None
345348

346349
if self.snapshot is not None:
347-
if self.step_configuration_schema:
350+
if (
351+
self.dynamic_step_configuration_schema
352+
or self.step_configuration_schema
353+
):
348354
pipeline_configuration = (
349355
PipelineConfiguration.model_validate_json(
350356
self.snapshot.pipeline_configuration
@@ -354,8 +360,12 @@ def get_step_configuration(self) -> Step:
354360
start_time=self.pipeline_run.start_time,
355361
inplace=True,
356362
)
363+
config_schema = (
364+
self.dynamic_step_configuration_schema
365+
or self.step_configuration_schema
366+
)
357367
step = Step.from_dict(
358-
json.loads(self.step_configuration_schema.config),
368+
json.loads(config_schema.config),
359369
pipeline_configuration=pipeline_configuration,
360370
)
361371
if not step and self.step_configuration:

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 18 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -5167,34 +5167,6 @@ def update_snapshot(
51675167
session=session,
51685168
)
51695169

5170-
if snapshot_update.add_steps:
5171-
# TODO: this doesn't work for scheduled runs that reuse a
5172-
# snapshot. For that, we'd have to separate config from the
5173-
# snapshot, which is something we want to do anyway.
5174-
if not snapshot.is_dynamic:
5175-
raise ValueError(
5176-
"Cannot dynamically update steps of a static snapshot."
5177-
)
5178-
5179-
# TODO: race conditions
5180-
current_index = len(snapshot.step_configurations)
5181-
for index, (step_name, step_configuration) in enumerate(
5182-
snapshot_update.add_steps.items()
5183-
):
5184-
step_configuration_schema = StepConfigurationSchema(
5185-
index=current_index + index,
5186-
name=step_name,
5187-
# Don't include the merged config in the step
5188-
# configurations, we reconstruct it in the `to_model` method
5189-
# using the pipeline configuration.
5190-
config=step_configuration.model_dump_json(
5191-
exclude={"config"}
5192-
),
5193-
snapshot_id=snapshot.id,
5194-
)
5195-
session.add(step_configuration_schema)
5196-
session.commit()
5197-
51985170
session.refresh(snapshot)
51995171
return snapshot.to_model(
52005172
include_metadata=True, include_resources=True
@@ -9677,7 +9649,10 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
96779649
session=session,
96789650
reference_type="original step run",
96799651
)
9680-
step_config = run.get_step_configuration(step_name=step_run.name)
9652+
step_config = (
9653+
step_run.dynamic_config
9654+
or run.get_step_configuration(step_name=step_run.name)
9655+
)
96819656

96829657
# Release the read locks of the previous two queries before we
96839658
# try to acquire more exclusive locks
@@ -9929,6 +9904,20 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
99299904
pipeline_run_id=step_run.pipeline_run_id, session=session
99309905
)
99319906

9907+
if step_run.dynamic_config:
9908+
step_configuration_schema = StepConfigurationSchema(
9909+
index=0,
9910+
name=step_run.name,
9911+
# Don't include the merged config in the step
9912+
# configurations; we reconstruct it in the `to_model` method
9913+
# using the pipeline configuration.
9914+
config=step_run.dynamic_config.model_dump_json(
9915+
exclude={"config"}
9916+
),
9917+
step_run_id=step_schema.id,
9918+
)
9919+
session.add(step_configuration_schema)
9920+
99329921
session.commit()
99339922
session.refresh(
99349923
step_schema, ["input_artifacts", "output_artifacts"]

0 commit comments

Comments
 (0)