|
21 | 21 | import requests |
22 | 22 |
|
23 | 23 | from zenml.client import Client |
| 24 | +from zenml.config.step_configurations import Step |
24 | 25 | from zenml.deployers.exceptions import ( |
25 | 26 | DeploymentHTTPError, |
26 | 27 | DeploymentNotFoundError, |
27 | 28 | DeploymentProvisionError, |
28 | 29 | ) |
29 | 30 | from zenml.enums import DeploymentStatus |
30 | | -from zenml.models import DeploymentResponse |
| 31 | +from zenml.models import ( |
| 32 | + CodeReferenceRequest, |
| 33 | + DeploymentResponse, |
| 34 | + PipelineSnapshotRequest, |
| 35 | + PipelineSnapshotResponse, |
| 36 | +) |
31 | 37 | from zenml.steps.step_context import get_step_context |
| 38 | +from zenml.utils import pydantic_utils |
32 | 39 | from zenml.utils.json_utils import pydantic_encoder |
33 | 40 |
|
34 | 41 |
|
@@ -274,3 +281,120 @@ def invoke_deployment( |
274 | 281 | raise DeploymentHTTPError( |
275 | 282 | f"Request failed for deployment {deployment_name_or_id}: {e}" |
276 | 283 | ) |
| 284 | + |
| 285 | + |
| 286 | +def deployment_snapshot_request_from_source_snapshot( |
| 287 | + source_snapshot: PipelineSnapshotResponse, |
| 288 | + deployment_parameters: Dict[str, Any], |
| 289 | +) -> PipelineSnapshotRequest: |
| 290 | + """Generate a snapshot request for deployment execution. |
| 291 | +
|
| 292 | + Args: |
| 293 | + source_snapshot: The source snapshot from which to create the |
| 294 | + snapshot request. |
| 295 | + deployment_parameters: Parameters to override for deployment execution. |
| 296 | +
|
| 297 | + Raises: |
| 298 | + RuntimeError: If the source snapshot does not have an associated stack. |
| 299 | +
|
| 300 | + Returns: |
| 301 | + The generated snapshot request. |
| 302 | + """ |
| 303 | + if source_snapshot.stack is None: |
| 304 | + raise RuntimeError("Missing source snapshot stack") |
| 305 | + |
| 306 | + pipeline_configuration = pydantic_utils.update_model( |
| 307 | + source_snapshot.pipeline_configuration, {"enable_cache": False} |
| 308 | + ) |
| 309 | + |
| 310 | + steps = {} |
| 311 | + for invocation_id, step in source_snapshot.step_configurations.items(): |
| 312 | + updated_step_parameters = step.config.parameters.copy() |
| 313 | + |
| 314 | + for param_name in step.config.parameters: |
| 315 | + if param_name in deployment_parameters: |
| 316 | + updated_step_parameters[param_name] = deployment_parameters[ |
| 317 | + param_name |
| 318 | + ] |
| 319 | + |
| 320 | + # Deployment-specific step overrides |
| 321 | + step_update = { |
| 322 | + "enable_cache": False, # Disable caching for all steps |
| 323 | + "step_operator": None, # Remove step operators for deployments |
| 324 | + "retry": None, # Remove retry configuration |
| 325 | + "parameters": updated_step_parameters, |
| 326 | + } |
| 327 | + |
| 328 | + step_config = pydantic_utils.update_model( |
| 329 | + step.step_config_overrides, step_update |
| 330 | + ) |
| 331 | + merged_step_config = step_config.apply_pipeline_configuration( |
| 332 | + pipeline_configuration |
| 333 | + ) |
| 334 | + |
| 335 | + steps[invocation_id] = Step( |
| 336 | + spec=step.spec, |
| 337 | + config=merged_step_config, |
| 338 | + step_config_overrides=step_config, |
| 339 | + ) |
| 340 | + |
| 341 | + code_reference_request = None |
| 342 | + if source_snapshot.code_reference: |
| 343 | + code_reference_request = CodeReferenceRequest( |
| 344 | + commit=source_snapshot.code_reference.commit, |
| 345 | + subdirectory=source_snapshot.code_reference.subdirectory, |
| 346 | + code_repository=source_snapshot.code_reference.code_repository.id, |
| 347 | + ) |
| 348 | + |
| 349 | + zenml_version = Client().zen_store.get_store_info().version |
| 350 | + |
| 351 | + # Compute the source snapshot ID: |
| 352 | + # - If the source snapshot has a name, we use it as the source snapshot. |
| 353 | + # That way, all runs will be associated with this snapshot. |
| 354 | + # - If the source snapshot is based on another snapshot (which therefore |
| 355 | + # has a name), we use that one instead. |
| 356 | + # - If the source snapshot does not have a name and is not based on another |
| 357 | + # snapshot, we don't set a source snapshot. |
| 358 | + # |
| 359 | + # With this, we ensure that all runs are associated with the closest named |
| 360 | + # source snapshot. |
| 361 | + source_snapshot_id = None |
| 362 | + if source_snapshot.name: |
| 363 | + source_snapshot_id = source_snapshot.id |
| 364 | + elif source_snapshot.source_snapshot_id: |
| 365 | + source_snapshot_id = source_snapshot.source_snapshot_id |
| 366 | + |
| 367 | + updated_pipeline_spec = source_snapshot.pipeline_spec |
| 368 | + if ( |
| 369 | + source_snapshot.pipeline_spec |
| 370 | + and source_snapshot.pipeline_spec.parameters is not None |
| 371 | + ): |
| 372 | + original_params: Dict[str, Any] = dict( |
| 373 | + source_snapshot.pipeline_spec.parameters |
| 374 | + ) |
| 375 | + merged_params: Dict[str, Any] = original_params.copy() |
| 376 | + for k, v in deployment_parameters.items(): |
| 377 | + if k in original_params: |
| 378 | + merged_params[k] = v |
| 379 | + updated_pipeline_spec = pydantic_utils.update_model( |
| 380 | + source_snapshot.pipeline_spec, {"parameters": merged_params} |
| 381 | + ) |
| 382 | + |
| 383 | + return PipelineSnapshotRequest( |
| 384 | + project=source_snapshot.project_id, |
| 385 | + run_name_template=source_snapshot.run_name_template, |
| 386 | + pipeline_configuration=pipeline_configuration, |
| 387 | + step_configurations=steps, |
| 388 | + client_environment={}, |
| 389 | + client_version=zenml_version, |
| 390 | + server_version=zenml_version, |
| 391 | + stack=source_snapshot.stack.id, |
| 392 | + pipeline=source_snapshot.pipeline.id, |
| 393 | + schedule=None, |
| 394 | + code_reference=code_reference_request, |
| 395 | + code_path=source_snapshot.code_path, |
| 396 | + build=source_snapshot.build.id if source_snapshot.build else None, |
| 397 | + source_snapshot=source_snapshot_id, |
| 398 | + pipeline_version_hash=source_snapshot.pipeline_version_hash, |
| 399 | + pipeline_spec=updated_pipeline_spec, |
| 400 | + ) |
0 commit comments