From 45abdd6c3546dfcd045aea73dce250a1b8d668f7 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 14:30:56 +0200 Subject: [PATCH 01/35] v0 --- src/zenml/integrations/modal/__init__.py | 2 +- .../flavors/modal_step_operator_flavor.py | 73 +++++- .../step_operators/modal_step_operator.py | 112 ++------ src/zenml/integrations/modal/utils.py | 242 ++++++++++++++++++ 4 files changed, 330 insertions(+), 99 deletions(-) create mode 100644 src/zenml/integrations/modal/utils.py diff --git a/src/zenml/integrations/modal/__init__.py b/src/zenml/integrations/modal/__init__.py index 081628cb035..49e78c1e774 100644 --- a/src/zenml/integrations/modal/__init__.py +++ b/src/zenml/integrations/modal/__init__.py @@ -29,7 +29,7 @@ class ModalIntegration(Integration): """Definition of Modal integration for ZenML.""" NAME = MODAL - REQUIREMENTS = ["modal>=0.64.49,<1"] + REQUIREMENTS = ["modal>=1"] @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py index a66580faa54..4585460e34f 100644 --- a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py +++ b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py @@ -15,9 +15,12 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR from zenml.step_operators import BaseStepOperatorConfig, BaseStepOperatorFlavor +from zenml.utils.secret_utils import SecretField if TYPE_CHECKING: from zenml.integrations.modal.step_operators import ModalStepOperator @@ -36,20 +39,80 @@ class ModalStepOperatorSettings(BaseSettings): incompatible. See more in the Modal docs at https://modal.com/docs/guide/region-selection. Attributes: - gpu: The type of GPU to use for the step execution. + gpu: The type of GPU to use for the step execution (e.g., "T4", "A100"). + Use ResourceSettings.gpu_count to specify the number of GPUs. region: The region to use for the step execution. cloud: The cloud provider to use for the step execution. + modal_environment: The Modal environment to use for the step execution. + timeout: Maximum execution time in seconds (default 24h). """ - gpu: Optional[str] = None - region: Optional[str] = None - cloud: Optional[str] = None + gpu: Optional[str] = Field( + None, + description="GPU type for step execution. Must be a valid Modal GPU type. " + "Examples: 'T4' (cost-effective), 'A100' (high-performance), 'V100' (training workloads). " + "Use ResourceSettings.gpu_count to specify number of GPUs. If not specified, uses CPU-only execution", + ) + region: Optional[str] = Field( + None, + description="Cloud region for step execution. Must be a valid region for the selected cloud provider. " + "Examples: 'us-east-1', 'us-west-2', 'eu-west-1'. If not specified, Modal uses default region " + "based on cloud provider and availability", + ) + cloud: Optional[str] = Field( + None, + description="Cloud provider for step execution. Must be a valid Modal-supported cloud provider. " + "Examples: 'aws', 'gcp'. If not specified, Modal uses default cloud provider " + "based on workspace configuration", + ) + modal_environment: Optional[str] = Field( + None, + description="Modal environment name for step execution. Must be a valid environment " + "configured in your Modal workspace. Examples: 'main', 'staging', 'production'. " + "If not specified, uses the default environment for the workspace", + ) + timeout: int = Field( + 86400, + description="Maximum execution time in seconds for step completion. Must be between 1 and 86400 seconds. " + "Examples: 3600 (1 hour), 7200 (2 hours), 86400 (24 hours maximum). " + "Step execution will be terminated if it exceeds this timeout", + ) class ModalStepOperatorConfig( BaseStepOperatorConfig, ModalStepOperatorSettings ): - """Configuration for the Modal step operator.""" + """Configuration for the Modal step operator. + + Attributes: + token_id: Modal API token ID (ak-xxxxx format) for authentication. + token_secret: Modal API token secret (as-xxxxx format) for authentication. + workspace: Modal workspace name (optional). + + Note: If token_id and token_secret are not provided, falls back to + Modal's default authentication (~/.modal.toml). + All other configuration options (modal_environment, gpu, region, etc.) + are inherited from ModalStepOperatorSettings. + """ + + token_id: Optional[str] = SecretField( + default=None, + description="Modal API token ID for authentication. Must be in format 'ak-xxxxx' as provided by Modal. " + "Example: 'ak-1234567890abcdef'. If not provided, falls back to Modal's default authentication " + "from ~/.modal.toml file. Required for programmatic access to Modal API", + ) + token_secret: Optional[str] = SecretField( + default=None, + description="Modal API token secret for authentication. Must be in format 'as-xxxxx' as provided by Modal. " + "Example: 'as-abcdef1234567890'. Used together with token_id for API authentication. " + "If not provided, falls back to Modal's default authentication from ~/.modal.toml file", + ) + workspace: Optional[str] = Field( + None, + description="Modal workspace name for step execution. Must be a valid workspace name " + "you have access to. Examples: 'my-company', 'ml-team', 'personal-workspace'. " + "If not specified, uses the default workspace from Modal configuration", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 9c654ca6541..5e83fa87263 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -14,21 +14,25 @@ """Modal step operator implementation.""" import asyncio -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast import modal -from modal_proto import api_pb2 from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration -from zenml.config.resource_settings import ByteUnit, ResourceSettings -from zenml.enums import StackComponentType +from zenml.config.resource_settings import ByteUnit from zenml.integrations.modal.flavors import ( ModalStepOperatorConfig, ModalStepOperatorSettings, ) +from zenml.integrations.modal.utils import ( + build_modal_image, + get_gpu_values, + get_modal_stack_validator, + setup_modal_client, +) from zenml.logger import get_logger -from zenml.stack import Stack, StackValidator +from zenml.stack import StackValidator from zenml.step_operators import BaseStepOperator if TYPE_CHECKING: @@ -41,24 +45,6 @@ MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY = "modal_step_operator" -def get_gpu_values( - settings: ModalStepOperatorSettings, resource_settings: ResourceSettings -) -> Optional[str]: - """Get the GPU values for the Modal step operator. - - Args: - settings: The Modal step operator settings. - resource_settings: The resource settings. - - Returns: - The GPU string if a count is specified, otherwise the GPU type. - """ - if not settings.gpu: - return None - gpu_count = resource_settings.gpu_count - return f"{settings.gpu}:{gpu_count}" if gpu_count else settings.gpu - - class ModalStepOperator(BaseStepOperator): """Step operator to run a step on Modal. @@ -91,40 +77,7 @@ def validator(self) -> Optional[StackValidator]: Returns: The stack validator. """ - - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: - if stack.artifact_store.config.is_local: - return False, ( - "The Modal step operator runs code remotely and " - "needs to write files into the artifact store, but the " - f"artifact store `{stack.artifact_store.name}` of the " - "active stack is local. Please ensure that your stack " - "contains a remote artifact store when using the Modal " - "step operator." - ) - - container_registry = stack.container_registry - assert container_registry is not None - - if container_registry.config.is_local: - return False, ( - "The Modal step operator runs code remotely and " - "needs to push/pull Docker images, but the " - f"container registry `{container_registry.name}` of the " - "active stack is local. Please ensure that your stack " - "contains a remote container registry when using the " - "Modal step operator." - ) - - return True, "" - - return StackValidator( - required_components={ - StackComponentType.CONTAINER_REGISTRY, - StackComponentType.IMAGE_BUILDER, - }, - custom_validation_function=_validate_remote_components, - ) + return get_modal_stack_validator() def get_docker_builds( self, snapshot: "PipelineSnapshotBase" @@ -161,51 +114,24 @@ def launch( info: The step run information. entrypoint_command: The entrypoint command for the step. environment: The environment variables for the step. - - Raises: - RuntimeError: If no Docker credentials are found for the container registry. - ValueError: If no container registry is found in the stack. """ settings = cast(ModalStepOperatorSettings, self.get_settings(info)) image_name = info.get_image(key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY) zc = Client() stack = zc.active_stack - if not stack.container_registry: - raise ValueError( - "No Container registry found in the stack. " - "Please add a container registry and ensure " - "it is correctly configured." - ) - - if docker_creds := stack.container_registry.credentials: - docker_username, docker_password = docker_creds - else: - raise RuntimeError( - "No Docker credentials found for the container registry." - ) - - my_secret = modal.secret._Secret.from_dict( - { - "REGISTRY_USERNAME": docker_username, - "REGISTRY_PASSWORD": docker_password, - } - ) - - spec = modal.image.DockerfileSpec( - commands=[f"FROM {image_name}"], context_files={} + setup_modal_client( + token_id=self.config.token_id, + token_secret=self.config.token_secret, + workspace=self.config.workspace, + environment=self.config.modal_environment, ) - zenml_image = modal.Image._from_args( - dockerfile_function=lambda *_, **__: spec, - force_build=False, - image_registry_config=modal.image._ImageRegistryConfig( - api_pb2.REGISTRY_AUTH_TYPE_STATIC_CREDS, my_secret - ), - ).env(environment) + # Build Modal image using shared utility + zenml_image = build_modal_image(image_name, stack, environment) resource_settings = info.config.resource_settings - gpu_values = get_gpu_values(settings, resource_settings) + gpu_values = get_gpu_values(settings.gpu, resource_settings) app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" @@ -231,7 +157,7 @@ async def run_sandbox() -> asyncio.Future[None]: cloud=settings.cloud, region=settings.region, app=app, - timeout=86400, # 24h, the max Modal allows + timeout=settings.timeout, ) await sb.wait.aio() diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py new file mode 100644 index 00000000000..96548ed4c57 --- /dev/null +++ b/src/zenml/integrations/modal/utils.py @@ -0,0 +1,242 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Shared utilities for Modal integration components.""" + +import os +from typing import TYPE_CHECKING, Dict, Optional, Tuple +from uuid import UUID + +import modal + +from zenml.config import ResourceSettings +from zenml.enums import StackComponentType +from zenml.logger import get_logger +from zenml.stack import Stack, StackValidator + +if TYPE_CHECKING: + from zenml.models import BuildItem + +logger = get_logger(__name__) + +ENV_ZENML_MODAL_ORCHESTRATOR_RUN_ID = "ZENML_MODAL_ORCHESTRATOR_RUN_ID" + + +def setup_modal_client( + token_id: Optional[str] = None, + token_secret: Optional[str] = None, + workspace: Optional[str] = None, + environment: Optional[str] = None, +) -> None: + """Setup Modal client with authentication. + + Args: + token_id: Modal API token ID (ak-xxxxx format). + token_secret: Modal API token secret (as-xxxxx format). + workspace: Modal workspace name. + environment: Modal environment name. + """ + if token_id and token_secret: + # Validate token format + if not token_id.startswith("ak-"): + logger.warning( + f"Token ID format may be invalid. Expected format: ak-xxxxx, " + f"got: {token_id[:10]}... (truncated for security)" + ) + + if not token_secret.startswith("as-"): + logger.warning( + f"Token secret format may be invalid. Expected format: as-xxxxx, " + f"got: {token_secret[:10]}... (truncated for security)" + ) + + # Set both token ID and secret + os.environ["MODAL_TOKEN_ID"] = token_id + os.environ["MODAL_TOKEN_SECRET"] = token_secret + logger.debug("Using platform token ID and secret from config") + logger.debug(f"Token ID starts with: {token_id[:5]}...") + logger.debug(f"Token secret starts with: {token_secret[:5]}...") + + elif token_id: + # Validate token format + if not token_id.startswith("ak-"): + logger.warning( + f"Token ID format may be invalid. Expected format: ak-xxxxx, " + f"got: {token_id[:10]}... (truncated for security)" + ) + + # Only token ID provided + os.environ["MODAL_TOKEN_ID"] = token_id + logger.debug("Using platform token ID from config") + logger.warning( + "Only token ID provided. Make sure MODAL_TOKEN_SECRET is set " + "or platform authentication may fail." + ) + logger.debug(f"Token ID starts with: {token_id[:5]}...") + + elif token_secret: + # Validate token format + if not token_secret.startswith("as-"): + logger.warning( + f"Token secret format may be invalid. Expected format: as-xxxxx, " + f"got: {token_secret[:10]}... (truncated for security)" + ) + + # Only token secret provided (unusual) + os.environ["MODAL_TOKEN_SECRET"] = token_secret + logger.warning( + "Only token secret provided. Make sure MODAL_TOKEN_ID is set " + "or platform authentication may fail." + ) + logger.debug(f"Token secret starts with: {token_secret[:5]}...") + + else: + logger.debug("Using default platform authentication (~/.modal.toml)") + # Check if default auth exists + modal_toml_path = os.path.expanduser("~/.modal.toml") + if os.path.exists(modal_toml_path): + logger.debug(f"Found platform config at {modal_toml_path}") + else: + logger.warning( + f"No platform config found at {modal_toml_path}. " + "Run 'modal token new' to set up authentication." + ) + + # Set workspace/environment if provided + if workspace: + os.environ["MODAL_WORKSPACE"] = workspace + if environment: + os.environ["MODAL_ENVIRONMENT"] = environment + + +# TODO: refactor step operator and remove this +def get_gpu_values( + gpu_type: Optional[str], resource_settings: ResourceSettings +) -> Optional[str]: + """Get the GPU values for Modal components. + + Args: + gpu_type: The GPU type from Modal settings (e.g., "T4", "A100"). + resource_settings: The resource settings containing GPU configuration. + + Returns: + The GPU string for Modal API, or None if no GPU requested. + """ + if not gpu_type: + return None + + gpu_count = resource_settings.gpu_count + if gpu_count == 0: + return None + elif gpu_count is None: + return gpu_type + else: + return f"{gpu_type}:{gpu_count}" + + +def build_modal_image( + image_name: str, + stack: "Stack", + environment: Optional[Dict[str, str]] = None, +) -> modal.Image: + """Build a Modal image from a Docker registry with authentication. + + This helper function centralizes the shared logic for building Modal images + from Docker registries, including credential validation, secret creation, + and image building with Modal installation. + + Args: + image_name: The name of the Docker image to use as base. + stack: The ZenML stack containing container registry. + environment: Optional environment variables to apply to the image. + + Returns: + The configured Modal image. + + Raises: + RuntimeError: If no Docker credentials are found. + """ + if not stack.container_registry: + raise RuntimeError( + "No Container registry found in the stack. " + "Please add a container registry and ensure " + "it is correctly configured." + ) + + if docker_creds := stack.container_registry.credentials: + docker_username, docker_password = docker_creds + else: + raise RuntimeError( + "No Docker credentials found for the container registry." + ) + + registry_secret = modal.Secret.from_dict( + { + "REGISTRY_USERNAME": docker_username, + "REGISTRY_PASSWORD": docker_password, + } + ) + + modal_image = modal.Image.from_registry( + image_name, secret=registry_secret + ).pip_install("modal") + + if environment: + modal_image = modal_image.env(environment) + + return modal_image + + +def get_modal_stack_validator() -> StackValidator: + """Get a stack validator for Modal components. + + The validator ensures that the stack contains a remote artifact store and + container registry. + + Returns: + A stack validator for modal components. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.config.is_local: + return False, ( + "Serverless components run code remotely and " + "need to write files into the artifact store, but the " + f"artifact store `{stack.artifact_store.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote artifact store when using serverless " + "components." + ) + + container_registry = stack.container_registry + assert container_registry is not None + + if container_registry.config.is_local: + return False, ( + "Serverless components run code remotely and " + "need to push/pull Docker images, but the " + f"container registry `{container_registry.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote container registry when using serverless " + "components." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.CONTAINER_REGISTRY, + StackComponentType.IMAGE_BUILDER, + }, + custom_validation_function=_validate_remote_components, + ) From a2455ec0c05f6166cebcbb3159113823ec5b98e6 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 14:40:51 +0200 Subject: [PATCH 02/35] Refactor Modal step operator GPU configuration Inline GPU configuration logic directly into the step operator instead of using a utility function. The logic is simple enough (10 lines of string formatting) that a shared utility adds unnecessary indirection. Changes: - Inline GPU value construction in modal_step_operator.py - Remove get_gpu_values utility function from utils.py - Remove test file that only tested trivial string formatting This addresses the TODO comment to refactor and remove the get_gpu_values helper function. --- .../step_operators/modal_step_operator.py | 13 +++++- src/zenml/integrations/modal/utils.py | 34 +------------- .../test_modal_step_operator.py | 46 ------------------- 3 files changed, 12 insertions(+), 81 deletions(-) delete mode 100644 tests/integration/integrations/modal/step_operators/test_modal_step_operator.py diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 5e83fa87263..87f1396ae74 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -27,7 +27,6 @@ ) from zenml.integrations.modal.utils import ( build_modal_image, - get_gpu_values, get_modal_stack_validator, setup_modal_client, ) @@ -131,7 +130,17 @@ def launch( zenml_image = build_modal_image(image_name, stack, environment) resource_settings = info.config.resource_settings - gpu_values = get_gpu_values(settings.gpu, resource_settings) + + # Determine GPU configuration from settings and resource settings + gpu_values = None + if settings.gpu: + gpu_count = resource_settings.gpu_count + if gpu_count == 0: + gpu_values = None + elif gpu_count is None: + gpu_values = settings.gpu + else: + gpu_values = f"{settings.gpu}:{gpu_count}" app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 96548ed4c57..2dc923a97da 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -14,23 +14,16 @@ """Shared utilities for Modal integration components.""" import os -from typing import TYPE_CHECKING, Dict, Optional, Tuple -from uuid import UUID +from typing import Dict, Optional, Tuple import modal -from zenml.config import ResourceSettings from zenml.enums import StackComponentType from zenml.logger import get_logger from zenml.stack import Stack, StackValidator -if TYPE_CHECKING: - from zenml.models import BuildItem - logger = get_logger(__name__) -ENV_ZENML_MODAL_ORCHESTRATOR_RUN_ID = "ZENML_MODAL_ORCHESTRATOR_RUN_ID" - def setup_modal_client( token_id: Optional[str] = None, @@ -119,31 +112,6 @@ def setup_modal_client( os.environ["MODAL_ENVIRONMENT"] = environment -# TODO: refactor step operator and remove this -def get_gpu_values( - gpu_type: Optional[str], resource_settings: ResourceSettings -) -> Optional[str]: - """Get the GPU values for Modal components. - - Args: - gpu_type: The GPU type from Modal settings (e.g., "T4", "A100"). - resource_settings: The resource settings containing GPU configuration. - - Returns: - The GPU string for Modal API, or None if no GPU requested. - """ - if not gpu_type: - return None - - gpu_count = resource_settings.gpu_count - if gpu_count == 0: - return None - elif gpu_count is None: - return gpu_type - else: - return f"{gpu_type}:{gpu_count}" - - def build_modal_image( image_name: str, stack: "Stack", diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py deleted file mode 100644 index 865f6d55265..00000000000 --- a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. - - -import pytest - -from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.modal.flavors import ModalStepOperatorSettings -from zenml.integrations.modal.step_operators.modal_step_operator import ( - get_gpu_values, -) - - -@pytest.mark.parametrize( - "gpu, gpu_count, expected_result", - [ - ("", None, None), - (None, None, None), - ("", 1, None), - (None, 1, None), - ("A100", None, "A100"), - ("A100", 0, "A100"), - ("A100", 1, "A100:1"), - ("A100", 2, "A100:2"), - ("V100", None, "V100"), - ("V100", 0, "V100"), - ("V100", 1, "V100:1"), - ("V100", 2, "V100:2"), - ], -) -def test_get_gpu_values(gpu, gpu_count, expected_result): - settings = ModalStepOperatorSettings(gpu=gpu) - resource_settings = ResourceSettings(gpu_count=gpu_count) - result = get_gpu_values(settings, resource_settings) - assert result == expected_result From 87c725e8064638dee1eaf3f9b9c5012f319c286b Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 14:43:14 +0200 Subject: [PATCH 03/35] Small refactoring --- .../modal/step_operators/modal_step_operator.py | 2 -- src/zenml/integrations/modal/utils.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 87f1396ae74..2701322f33f 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -126,12 +126,10 @@ def launch( environment=self.config.modal_environment, ) - # Build Modal image using shared utility zenml_image = build_modal_image(image_name, stack, environment) resource_settings = info.config.resource_settings - # Determine GPU configuration from settings and resource settings gpu_values = None if settings.gpu: gpu_count = resource_settings.gpu_count diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 2dc923a97da..1bbbafd5b99 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -40,7 +40,6 @@ def setup_modal_client( environment: Modal environment name. """ if token_id and token_secret: - # Validate token format if not token_id.startswith("ak-"): logger.warning( f"Token ID format may be invalid. Expected format: ak-xxxxx, " @@ -53,7 +52,6 @@ def setup_modal_client( f"got: {token_secret[:10]}... (truncated for security)" ) - # Set both token ID and secret os.environ["MODAL_TOKEN_ID"] = token_id os.environ["MODAL_TOKEN_SECRET"] = token_secret logger.debug("Using platform token ID and secret from config") @@ -61,14 +59,12 @@ def setup_modal_client( logger.debug(f"Token secret starts with: {token_secret[:5]}...") elif token_id: - # Validate token format if not token_id.startswith("ak-"): logger.warning( f"Token ID format may be invalid. Expected format: ak-xxxxx, " f"got: {token_id[:10]}... (truncated for security)" ) - # Only token ID provided os.environ["MODAL_TOKEN_ID"] = token_id logger.debug("Using platform token ID from config") logger.warning( @@ -78,7 +74,6 @@ def setup_modal_client( logger.debug(f"Token ID starts with: {token_id[:5]}...") elif token_secret: - # Validate token format if not token_secret.startswith("as-"): logger.warning( f"Token secret format may be invalid. Expected format: as-xxxxx, " @@ -95,7 +90,6 @@ def setup_modal_client( else: logger.debug("Using default platform authentication (~/.modal.toml)") - # Check if default auth exists modal_toml_path = os.path.expanduser("~/.modal.toml") if os.path.exists(modal_toml_path): logger.debug(f"Found platform config at {modal_toml_path}") @@ -105,7 +99,6 @@ def setup_modal_client( "Run 'modal token new' to set up authentication." ) - # Set workspace/environment if provided if workspace: os.environ["MODAL_WORKSPACE"] = workspace if environment: From 4481a3faee8aa22f76fc219c562d23710891345c Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:25:30 +0200 Subject: [PATCH 04/35] Remove unnecessary asyncio complexity --- .../modal/step_operators/modal_step_operator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 2701322f33f..a269b9089df 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -144,9 +144,7 @@ def launch( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" ) - async def run_sandbox() -> asyncio.Future[None]: - loop = asyncio.get_event_loop() - future = loop.create_future() + async def run_sandbox() -> None: with modal.enable_output(): async with app.run(): memory_mb = resource_settings.get_memory(ByteUnit.MB) @@ -169,7 +167,4 @@ async def run_sandbox() -> asyncio.Future[None]: await sb.wait.aio() - future.set_result(None) - return future - asyncio.run(run_sandbox()) From 04508532b4004f938d1e38a3e6cae31d0e34ab18 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:36:26 +0200 Subject: [PATCH 05/35] Refactor token checks --- src/zenml/integrations/modal/utils.py | 137 ++++++++++++++++---------- 1 file changed, 84 insertions(+), 53 deletions(-) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 1bbbafd5b99..038b06ad251 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -24,6 +24,37 @@ logger = get_logger(__name__) +MODAL_TOKEN_ID_ENV = "MODAL_TOKEN_ID" +MODAL_TOKEN_SECRET_ENV = "MODAL_TOKEN_SECRET" +MODAL_WORKSPACE_ENV = "MODAL_WORKSPACE" +MODAL_ENVIRONMENT_ENV = "MODAL_ENVIRONMENT" +MODAL_CONFIG_PATH = os.path.expanduser("~/.modal.toml") + + +def _validate_token_prefix( + value: str, expected_prefix: str, label: str +) -> None: + """Warn if a credential doesn't match Modal's expected prefix. + + This helps catch misconfigurations early without logging secret content. + """ + if not value.startswith(expected_prefix): + logger.warning( + f"{label} format may be invalid. Expected prefix: {expected_prefix}" + ) + + +def _set_env_if_present(var_name: str, value: Optional[str]) -> bool: + """Set an environment variable only if a non-empty value is provided. + + Returns: + True if the environment variable was set, False otherwise. + """ + if value is None or value == "": + return False + os.environ[var_name] = value + return True + def setup_modal_client( token_id: Optional[str] = None, @@ -31,7 +62,17 @@ def setup_modal_client( workspace: Optional[str] = None, environment: Optional[str] = None, ) -> None: - """Setup Modal client with authentication. + """Setup Modal client authentication and context. + + Precedence for credentials: + 1) Explicit arguments (token_id, token_secret) + 2) Existing environment variables (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) + 3) Default Modal configuration (~/.modal.toml) + + Notes: + - The 'environment' parameter refers to the Modal environment name (e.g., 'main'), + not a dict of process environment variables. + - This function avoids logging secret values. Validation only checks known prefixes. Args: token_id: Modal API token ID (ak-xxxxx format). @@ -39,70 +80,60 @@ def setup_modal_client( workspace: Modal workspace name. environment: Modal environment name. """ - if token_id and token_secret: - if not token_id.startswith("ak-"): + # Remember whether values came from args vs. env to improve diagnostics without leaking secrets. + arg_token_id = token_id + arg_token_secret = token_secret + + # Coalesce from env if not explicitly provided. This reduces friction if users + # supply one value via configuration and the other via environment. + token_id = token_id or os.environ.get(MODAL_TOKEN_ID_ENV) + token_secret = token_secret or os.environ.get(MODAL_TOKEN_SECRET_ENV) + + tokens_provided = [] + token_sources: Dict[str, str] = {} + + if token_id: + _validate_token_prefix(token_id, "ak-", "Token ID") + _set_env_if_present(MODAL_TOKEN_ID_ENV, token_id) + tokens_provided.append("ID") + token_sources["ID"] = "args" if arg_token_id else "env" + + if token_secret: + _validate_token_prefix(token_secret, "as-", "Token secret") + _set_env_if_present(MODAL_TOKEN_SECRET_ENV, token_secret) + tokens_provided.append("secret") + token_sources["secret"] = "args" if arg_token_secret else "env" + + if tokens_provided: + if len(tokens_provided) == 1: logger.warning( - f"Token ID format may be invalid. Expected format: ak-xxxxx, " - f"got: {token_id[:10]}... (truncated for security)" + f"Only token {tokens_provided[0]} provided. Ensure both are set." ) - - if not token_secret.startswith("as-"): - logger.warning( - f"Token secret format may be invalid. Expected format: as-xxxxx, " - f"got: {token_secret[:10]}... (truncated for security)" - ) - - os.environ["MODAL_TOKEN_ID"] = token_id - os.environ["MODAL_TOKEN_SECRET"] = token_secret - logger.debug("Using platform token ID and secret from config") - logger.debug(f"Token ID starts with: {token_id[:5]}...") - logger.debug(f"Token secret starts with: {token_secret[:5]}...") - - elif token_id: - if not token_id.startswith("ak-"): - logger.warning( - f"Token ID format may be invalid. Expected format: ak-xxxxx, " - f"got: {token_id[:10]}... (truncated for security)" - ) - - os.environ["MODAL_TOKEN_ID"] = token_id - logger.debug("Using platform token ID from config") - logger.warning( - "Only token ID provided. Make sure MODAL_TOKEN_SECRET is set " - "or platform authentication may fail." + source_parts: List[str] = [] + if "ID" in token_sources: + source_parts.append(f"ID from {token_sources['ID']}") + if "secret" in token_sources: + source_parts.append(f"secret from {token_sources['secret']}") + source_summary = ( + " and ".join(source_parts) if source_parts else "args/env" ) - logger.debug(f"Token ID starts with: {token_id[:5]}...") - - elif token_secret: - if not token_secret.startswith("as-"): - logger.warning( - f"Token secret format may be invalid. Expected format: as-xxxxx, " - f"got: {token_secret[:10]}... (truncated for security)" - ) - - # Only token secret provided (unusual) - os.environ["MODAL_TOKEN_SECRET"] = token_secret - logger.warning( - "Only token secret provided. Make sure MODAL_TOKEN_ID is set " - "or platform authentication may fail." - ) - logger.debug(f"Token secret starts with: {token_secret[:5]}...") - + logger.debug(f"Using Modal API tokens ({source_summary}).") else: + # Fall back to default Modal CLI auth configuration. logger.debug("Using default platform authentication (~/.modal.toml)") - modal_toml_path = os.path.expanduser("~/.modal.toml") - if os.path.exists(modal_toml_path): - logger.debug(f"Found platform config at {modal_toml_path}") + if os.path.exists(MODAL_CONFIG_PATH): + logger.debug(f"Found platform config at {MODAL_CONFIG_PATH}") else: logger.warning( - f"No platform config found at {modal_toml_path}. " + f"No platform config found at {MODAL_CONFIG_PATH}. " "Run 'modal token new' to set up authentication." ) + # Configure Modal workspace/environment context if provided. if workspace: - os.environ["MODAL_WORKSPACE"] = workspace + _set_env_if_present(MODAL_WORKSPACE_ENV, workspace) if environment: - os.environ["MODAL_ENVIRONMENT"] = environment + _set_env_if_present(MODAL_ENVIRONMENT_ENV, environment) def build_modal_image( From 705b910687007a41291c9d1fe14cd49cafc39820 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:38:28 +0200 Subject: [PATCH 06/35] Make environment parameter optional in launch method Aligns type hint with build_modal_image signature for consistency. --- .../integrations/modal/step_operators/modal_step_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index a269b9089df..af9d3b3d8ea 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -105,7 +105,7 @@ def launch( self, info: "StepRunInfo", entrypoint_command: List[str], - environment: Dict[str, str], + environment: Optional[Dict[str, str]], ) -> None: """Launch a step run on Modal. From d19ffeb52e84ac6e14eb892d0e8c5222eb08adff Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:40:22 +0200 Subject: [PATCH 07/35] Simplify memory conversion using walrus operator Reduces verbosity by combining assignment and conditional check. --- .../modal/step_operators/modal_step_operator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index af9d3b3d8ea..876c544d3e6 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -147,9 +147,11 @@ def launch( async def run_sandbox() -> None: with modal.enable_output(): async with app.run(): - memory_mb = resource_settings.get_memory(ByteUnit.MB) memory_int = ( - int(memory_mb) if memory_mb is not None else None + int(mb) + if (mb := resource_settings.get_memory(ByteUnit.MB)) + is not None + else None ) sb = await modal.Sandbox.create.aio( "bash", From 9e91e94f993af5968b1ca581ec7115fafefb9c2a Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:43:32 +0200 Subject: [PATCH 08/35] Use consistent truthy checks for None handling Aligns with existing pattern used throughout Modal integration. --- .../integrations/modal/step_operators/modal_step_operator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 876c544d3e6..2f5d271e649 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -150,7 +150,6 @@ async def run_sandbox() -> None: memory_int = ( int(mb) if (mb := resource_settings.get_memory(ByteUnit.MB)) - is not None else None ) sb = await modal.Sandbox.create.aio( From fdbdea8900621080abc4aa43133227336777b0c2 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 15:45:14 +0200 Subject: [PATCH 09/35] Extract timeout default into named constant Improves maintainability by defining the value in a single location. --- .../modal/flavors/modal_step_operator_flavor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py index 4585460e34f..fc707738caa 100644 --- a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py +++ b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from zenml.integrations.modal.step_operators import ModalStepOperator +DEFAULT_TIMEOUT_SECONDS = 86400 # 24 hours + class ModalStepOperatorSettings(BaseSettings): """Settings for the Modal step operator. @@ -72,9 +74,9 @@ class ModalStepOperatorSettings(BaseSettings): "If not specified, uses the default environment for the workspace", ) timeout: int = Field( - 86400, - description="Maximum execution time in seconds for step completion. Must be between 1 and 86400 seconds. " - "Examples: 3600 (1 hour), 7200 (2 hours), 86400 (24 hours maximum). " + DEFAULT_TIMEOUT_SECONDS, + description=f"Maximum execution time in seconds for step completion. Must be between 1 and {DEFAULT_TIMEOUT_SECONDS} seconds. " + f"Examples: 3600 (1 hour), 7200 (2 hours), {DEFAULT_TIMEOUT_SECONDS} (24 hours maximum). " "Step execution will be terminated if it exceeds this timeout", ) From 8cf809db9e8332c4a9a6c0211a79c64bb21c517f Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 16:07:35 +0200 Subject: [PATCH 10/35] Fix GPU settings validation in Modal step operator Previously, setting gpu_count > 0 without specifying a GPU type would silently run on CPU with no warning. This adds comprehensive validation that raises clear configuration errors with actionable guidance. Changes: - Add _compute_modal_gpu_arg() helper to validate and format GPU settings - Raise StackComponentInterfaceError when gpu_count > 0 but gpu is None - Warn and default to 1 GPU when gpu is set but gpu_count is 0 - Validate gpu_count is non-negative integer - Normalize whitespace in GPU type strings - Add missing List import to utils.py --- .../step_operators/modal_step_operator.py | 89 +++++++++++++++++-- src/zenml/integrations/modal/utils.py | 2 +- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 2f5d271e649..ba74c9eca0b 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -21,6 +21,7 @@ from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration from zenml.config.resource_settings import ByteUnit +from zenml.exceptions import StackComponentInterfaceError from zenml.integrations.modal.flavors import ( ModalStepOperatorConfig, ModalStepOperatorSettings, @@ -101,6 +102,83 @@ def get_docker_builds( return builds + def _compute_modal_gpu_arg( + self, + settings: ModalStepOperatorSettings, + resource_settings, + ) -> Optional[str]: + """Compute and validate the Modal 'gpu' argument. + + Why this exists: + - Modal expects GPU resources as a string (e.g., 'T4' or 'A100:2'). + - ZenML splits GPU intent between a 'type' (settings.gpu) and a 'count' + (resource_settings.gpu_count). This helper reconciles the two and + enforces rules that prevent silent CPU fallbacks or ambiguous configs. + + Rules enforced: + - If a positive gpu_count is requested without specifying a GPU type, + raise a StackComponentInterfaceError to make the mismatch explicit. + - If a GPU type is specified but gpu_count == 0, we interpret this as + requesting 1 GPU (Modal semantics for a bare type string) and log a + warning to explain the behavior and how to request CPU-only runs. + - If neither a type nor a positive count is requested, return None for + CPU-only execution. + - Otherwise, format the string as '' for one GPU or ':' + for multiple GPUs. + """ + # Normalize GPU type: treat empty or whitespace-only strings as None to avoid + # surprising behavior when user-provided values are malformed (e.g., " "). + gpu_type_raw = settings.gpu + gpu_type = gpu_type_raw.strip() if gpu_type_raw is not None else None + if gpu_type == "": + gpu_type = None + + # Coerce and validate gpu_count to ensure it's a non-negative integer if provided. + gpu_count = resource_settings.gpu_count + if gpu_count is not None: + try: + gpu_count = int(gpu_count) + except (TypeError, ValueError): + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be a non-negative integer." + ) + if gpu_count < 0: + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be >= 0." + ) + + # Scenario 1: Count requested but type missing -> invalid configuration. + if gpu_type is None: + if gpu_count is not None and gpu_count > 0: + raise StackComponentInterfaceError( + "GPU resources requested (gpu_count > 0) but no GPU type was specified " + "in Modal settings. Please set a GPU type (e.g., 'T4', 'A100') via " + "ModalStepOperatorSettings.gpu or @step(settings={'modal': {'gpu': ''}}), " + "or set gpu_count=0 to run on CPU." + ) + # CPU-only scenarios (no type, count is None or 0). + return None + + # Scenario 2: Type set but count == 0 -> warn and default to 1 GPU. + if gpu_count == 0: + logger.warning( + "Modal GPU type '%s' is configured but ResourceSettings.gpu_count is 0. " + "Defaulting to 1 GPU. To run on CPU only, remove the GPU type or ensure " + "gpu_count=0 with no GPU type configured.", + gpu_type, + ) + return gpu_type # Implicitly 1 GPU for Modal + + # Valid configurations + if gpu_count is None: + return gpu_type # Implicitly 1 GPU for Modal + + if gpu_count > 0: + return f"{gpu_type}:{gpu_count}" + + # Defensive fallback shouldn't be reachable due to validation above; return CPU-only None if hit. + return None + def launch( self, info: "StepRunInfo", @@ -130,15 +208,8 @@ def launch( resource_settings = info.config.resource_settings - gpu_values = None - if settings.gpu: - gpu_count = resource_settings.gpu_count - if gpu_count == 0: - gpu_values = None - elif gpu_count is None: - gpu_values = settings.gpu - else: - gpu_values = f"{settings.gpu}:{gpu_count}" + # Compute and validate the GPU argument via the helper. + gpu_values = self._compute_modal_gpu_arg(settings, resource_settings) app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 038b06ad251..25155f03d28 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -14,7 +14,7 @@ """Shared utilities for Modal integration components.""" import os -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import modal From a839fefc2c0a216e502f00e12b52063a662c86ae Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 16:15:15 +0200 Subject: [PATCH 11/35] Allow step-level modal_environment overrides Previously, setup_modal_client always used the component-level modal_environment config, preventing per-step overrides via settings. Now prioritizes settings.modal_environment when provided, falling back to component config for backward compatibility. --- .../integrations/modal/step_operators/modal_step_operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index ba74c9eca0b..fb6988c5765 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -201,7 +201,8 @@ def launch( token_id=self.config.token_id, token_secret=self.config.token_secret, workspace=self.config.workspace, - environment=self.config.modal_environment, + environment=settings.modal_environment + or self.config.modal_environment, ) zenml_image = build_modal_image(image_name, stack, environment) From 4dcf11b692000ce80e977cf5be81a1d1c6ab632a Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 16:31:09 +0200 Subject: [PATCH 12/35] Add unit tests for complex helper functions --- .../step_operators/test_flavor_config.py | 22 ++++ .../test_modal_step_operator_gpu.py | 113 +++++++++++++++++ .../test_utils_build_and_validator.py | 56 +++++++++ .../modal/step_operators/test_utils_env.py | 119 ++++++++++++++++++ 4 files changed, 310 insertions(+) create mode 100644 tests/integration/integrations/modal/step_operators/test_flavor_config.py create mode 100644 tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py create mode 100644 tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py create mode 100644 tests/integration/integrations/modal/step_operators/test_utils_env.py diff --git a/tests/integration/integrations/modal/step_operators/test_flavor_config.py b/tests/integration/integrations/modal/step_operators/test_flavor_config.py new file mode 100644 index 00000000000..ac464892c3a --- /dev/null +++ b/tests/integration/integrations/modal/step_operators/test_flavor_config.py @@ -0,0 +1,22 @@ +from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR +from zenml.integrations.modal.flavors.modal_step_operator_flavor import ( + DEFAULT_TIMEOUT_SECONDS, + ModalStepOperatorConfig, + ModalStepOperatorFlavor, + ModalStepOperatorSettings, +) + + +def test_modal_settings_default_timeout() -> None: + settings = ModalStepOperatorSettings() + assert settings.timeout == DEFAULT_TIMEOUT_SECONDS + + +def test_modal_config_is_remote_true() -> None: + cfg = ModalStepOperatorConfig() + assert cfg.is_remote is True + + +def test_modal_flavor_name_constant() -> None: + flavor = ModalStepOperatorFlavor() + assert flavor.name == MODAL_STEP_OPERATOR_FLAVOR diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py new file mode 100644 index 00000000000..07f575b4c86 --- /dev/null +++ b/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py @@ -0,0 +1,113 @@ +import pytest + +# Skip this entire module if the optional dependency isn't available, because +# the module under test imports `modal` at import time. +pytest.importorskip("modal") + +from zenml.exceptions import StackComponentInterfaceError +from zenml.integrations.modal.flavors import ModalStepOperatorSettings +from zenml.integrations.modal.step_operators.modal_step_operator import ( + ModalStepOperator, +) + + +class ResourceSettingsStub: + """Minimal stub to simulate ZenML ResourceSettings for GPU tests. + + We only model the `gpu_count` attribute because that's the only part the + helper uses. This keeps tests lightweight and avoids wider dependencies. + """ + + def __init__(self, gpu_count): + self.gpu_count = gpu_count + + +def _make_operator() -> ModalStepOperator: + # Bypass BaseStepOperator initialization since we only need the helper. + return ModalStepOperator.__new__(ModalStepOperator) + + +def test_gpu_arg_none_when_no_type_and_no_count() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu=None) + rs = ResourceSettingsStub(gpu_count=None) + assert op._compute_modal_gpu_arg(settings, rs) is None + + +def test_gpu_arg_raises_when_count_without_type() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu=None) + rs = ResourceSettingsStub(gpu_count=1) + with pytest.raises(StackComponentInterfaceError) as e: + op._compute_modal_gpu_arg(settings, rs) + assert ( + "GPU resources requested (gpu_count > 0) but no GPU type was specified" + in str(e.value) + ) + + +def test_gpu_arg_type_with_no_count_returns_type() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu="A100") + rs = ResourceSettingsStub(gpu_count=None) + assert op._compute_modal_gpu_arg(settings, rs) == "A100" + + +def test_gpu_arg_type_with_count_returns_type_colon_count() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu="A100") + + rs_two = ResourceSettingsStub(gpu_count=2) + assert op._compute_modal_gpu_arg(settings, rs_two) == "A100:2" + + rs_one = ResourceSettingsStub(gpu_count=1) + assert op._compute_modal_gpu_arg(settings, rs_one) == "A100:1" + + +def test_gpu_arg_type_with_zero_count_warns_and_defaults_to_single_gpu( + caplog, +) -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu="A100") + rs = ResourceSettingsStub(gpu_count=0) + + with caplog.at_level("WARNING"): + result = op._compute_modal_gpu_arg(settings, rs) + + assert result == "A100" + assert "Defaulting to 1 GPU" in caplog.text + + +def test_gpu_arg_invalid_negative_count_raises() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu="T4") + rs = ResourceSettingsStub(gpu_count=-1) + with pytest.raises(StackComponentInterfaceError) as e: + op._compute_modal_gpu_arg(settings, rs) + assert "Invalid GPU count" in str(e.value) + + +def test_gpu_arg_non_integer_count_raises() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu="T4") + rs = ResourceSettingsStub(gpu_count="two") + with pytest.raises(StackComponentInterfaceError) as e: + op._compute_modal_gpu_arg(settings, rs) + assert "Invalid GPU count" in str(e.value) + + +def test_gpu_arg_whitespace_type_treated_as_none_behavior() -> None: + op = _make_operator() + settings = ModalStepOperatorSettings(gpu=" ") + + # With positive GPU count this should raise since type is treated as None. + rs_positive = ResourceSettingsStub(gpu_count=2) + with pytest.raises(StackComponentInterfaceError): + op._compute_modal_gpu_arg(settings, rs_positive) + + # With zero or None count, this should be CPU-only (None). + rs_zero = ResourceSettingsStub(gpu_count=0) + assert op._compute_modal_gpu_arg(settings, rs_zero) is None + + rs_none = ResourceSettingsStub(gpu_count=None) + assert op._compute_modal_gpu_arg(settings, rs_none) is None diff --git a/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py b/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py new file mode 100644 index 00000000000..5735212cf8e --- /dev/null +++ b/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py @@ -0,0 +1,56 @@ +import pytest + +# Skip this entire module if the optional dependency isn't available, because +# the module under test imports `modal` at import time. +pytest.importorskip("modal") + +from zenml.integrations.modal.utils import ( + build_modal_image, + get_modal_stack_validator, +) +from zenml.stack.stack_validator import StackValidator + + +class StackStubNoRegistry: + """Stack stub with no container registry to trigger early validation.""" + + container_registry = None + + +class ContainerRegistryStubNoCreds: + """Container registry stub with missing credentials.""" + + def __init__(self): + self.credentials = None + + +class StackStubWithRegistryNoCreds: + """Stack stub with a container registry but no credentials.""" + + def __init__(self): + self.container_registry = ContainerRegistryStubNoCreds() + + +def test_build_modal_image_raises_when_no_registry() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubNoRegistry(), environment=None + ) + assert "No Container registry found in the stack" in str(e.value) + + +def test_build_modal_image_raises_when_no_credentials() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubWithRegistryNoCreds(), environment=None + ) + assert "No Docker credentials found for the container registry" in str( + e.value + ) + + +def test_get_modal_stack_validator_returns_stackvalidator_instance() -> None: + validator = get_modal_stack_validator() + assert isinstance(validator, StackValidator) + # Ensure the validator exposes a validate-like callable (public contract) + assert hasattr(validator, "validate") and callable(validator.validate) diff --git a/tests/integration/integrations/modal/step_operators/test_utils_env.py b/tests/integration/integrations/modal/step_operators/test_utils_env.py new file mode 100644 index 00000000000..29acaf666f0 --- /dev/null +++ b/tests/integration/integrations/modal/step_operators/test_utils_env.py @@ -0,0 +1,119 @@ +import os + +import pytest + +# Skip this entire module if the optional dependency isn't available, because +# the module under test imports `modal` at import time. +pytest.importorskip("modal") + +from zenml.integrations.modal.utils import ( + MODAL_CONFIG_PATH, + MODAL_ENVIRONMENT_ENV, + MODAL_TOKEN_ID_ENV, + MODAL_TOKEN_SECRET_ENV, + MODAL_WORKSPACE_ENV, + _set_env_if_present, + _validate_token_prefix, + setup_modal_client, +) + + +def test_set_env_if_present_sets_and_returns_true(monkeypatch) -> None: + monkeypatch.delenv("FOO_TEST", raising=False) + assert _set_env_if_present("FOO_TEST", "bar") is True + assert os.environ["FOO_TEST"] == "bar" + + +def test_set_env_if_present_ignores_none_and_empty(monkeypatch) -> None: + monkeypatch.delenv("FOO_TEST", raising=False) + + assert _set_env_if_present("FOO_TEST", None) is False + assert "FOO_TEST" not in os.environ + + assert _set_env_if_present("FOO_TEST", "") is False + assert "FOO_TEST" not in os.environ + + +def test_validate_token_prefix_warns_on_invalid_prefix(caplog) -> None: + with caplog.at_level("WARNING"): + _validate_token_prefix("invalid-token", "ak-", "Token ID") + assert "Expected prefix: ak-" in caplog.text + + +def test_setup_modal_client_sets_env_from_args(monkeypatch) -> None: + # Ensure a clean environment + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + monkeypatch.delenv(MODAL_WORKSPACE_ENV, raising=False) + monkeypatch.delenv(MODAL_ENVIRONMENT_ENV, raising=False) + + setup_modal_client( + token_id="ak-abc123", + token_secret="as-def456", + workspace="my-ws", + environment="main", + ) + + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-abc123" + assert os.environ[MODAL_TOKEN_SECRET_ENV] == "as-def456" + assert os.environ[MODAL_WORKSPACE_ENV] == "my-ws" + assert os.environ[MODAL_ENVIRONMENT_ENV] == "main" + + +def test_setup_modal_client_warns_when_only_one_token_provided( + monkeypatch, caplog +) -> None: + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + with caplog.at_level("WARNING"): + setup_modal_client(token_id="ak-abc123", token_secret=None) + + assert "Only token ID provided" in caplog.text + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-abc123" + assert MODAL_TOKEN_SECRET_ENV not in os.environ + + +def test_setup_modal_client_prefers_args_over_env(monkeypatch) -> None: + monkeypatch.setenv(MODAL_TOKEN_ID_ENV, "ak-old") + monkeypatch.setenv(MODAL_TOKEN_SECRET_ENV, "as-old") + + setup_modal_client(token_id="ak-new", token_secret="as-new") + + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-new" + assert os.environ[MODAL_TOKEN_SECRET_ENV] == "as-new" + + +def test_setup_modal_client_logs_missing_config_when_no_tokens_and_no_config( + monkeypatch, caplog +) -> None: + # Clear tokens from environment + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + # Pretend modal config is missing + monkeypatch.setenv( + "HOME", "/nonexistent-home-for-test" + ) # defensive isolation + monkeypatch.setattr("os.path.exists", lambda path: False) + + with caplog.at_level("WARNING"): + setup_modal_client() + + # Should warn that the default config file is missing + assert "No platform config found at" in caplog.text + assert str(MODAL_CONFIG_PATH) in caplog.text + + +def test_setup_modal_client_warns_on_bad_token_prefixes( + monkeypatch, caplog +) -> None: + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + with caplog.at_level("WARNING"): + setup_modal_client(token_id="bad-id", token_secret="bad-secret") + + # Both ID and secret should trigger prefix warnings + assert "Expected prefix: ak-" in caplog.text + assert "Expected prefix: as-" in caplog.text From 6ffa074a25be82002f2473e900243beaf26a3205 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:22:44 +0200 Subject: [PATCH 13/35] Update docs page --- .../component-guide/step-operators/modal.md | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/docs/book/component-guide/step-operators/modal.md b/docs/book/component-guide/step-operators/modal.md index 83c0ec99035..2f18b6266fb 100644 --- a/docs/book/component-guide/step-operators/modal.md +++ b/docs/book/component-guide/step-operators/modal.md @@ -36,6 +36,13 @@ To use the Modal step operator, we need: cloud artifact store supported by ZenML will work with Modal. * A cloud container registry as part of your stack. Any cloud container registry supported by ZenML will work with Modal. +* An Image Builder in your stack. ZenML uses it to build the Docker image that + runs on Modal. + +The Modal step operator also respects the following environment variables if set: +- MODAL_TOKEN_ID, MODAL_TOKEN_SECRET: authentication tokens +- MODAL_WORKSPACE: workspace name +- MODAL_ENVIRONMENT: Modal environment name (e.g., "main") We can then register the step operator: @@ -66,30 +73,42 @@ You can specify the hardware requirements for each step using the `ResourceSettings` class as described in our documentation on [resource settings](https://docs.zenml.io/user-guides/tutorial/distributed-training): ```python +from zenml import step from zenml.config import ResourceSettings from zenml.integrations.modal.flavors import ModalStepOperatorSettings -modal_settings = ModalStepOperatorSettings(gpu="A100") +modal_settings = ModalStepOperatorSettings( + gpu="A100", # GPU type (e.g., "T4", "A100") + # region="us-east-1", # optional, enterprise/team only + # cloud="aws", # optional, enterprise/team only + # modal_environment="main", # optional + # timeout=86400, # optional, seconds +) + resource_settings = ResourceSettings( - cpu=2, - memory="32GB" + cpu_count=2, + memory="32GB", + # gpu_count=1, # optional; if omitted and a GPU type is set, defaults to 1 GPU ) @step( - step_operator=True, + step_operator=True, # or the specific name, e.g., step_operator="" settings={ "step_operator": modal_settings, - "resources": resource_settings - } + "resources": resource_settings, + }, ) def my_modal_step(): ... ``` +Important: +- If you request GPUs with `ResourceSettings.gpu_count > 0`, you must also specify a GPU type via `ModalStepOperatorSettings.gpu`; otherwise the run will fail with a validation error. +- If a GPU type is set but `gpu_count == 0`, ZenML defaults to 1 GPU and logs a warning. +- `cpu_count` must be an integer. `memory` can be a string like "32GB" or an integer amount of bytes. + {% hint style="info" %} Note that the `cpu` parameter in `ResourceSettings` currently only accepts a single integer value. This specifies a soft minimum limit - Modal will guarantee at least this many physical cores, but the actual usage could be higher. The CPU cores/hour will also determine the minimum price paid for the compute resources. - -For example, with the configuration above (2 CPUs and 32GB memory), the minimum cost would be approximately $1.03 per hour ((0.135 * 2) + (0.024 * 32) = $1.03). {% endhint %} This will run `my_modal_step` on a Modal instance with 1 A100 GPU, 2 CPUs, and @@ -108,8 +127,3 @@ pipeline execution failures. In the case of failures, however, Modal provides detailed error messages that can help identify what is incompatible. See more in the [Modal docs on region selection](https://modal.com/docs/guide/region-selection) for more details. - - -
ZenML Scarf
- - From 00232ea84fb548525cff3abb6b063c1152b4d9bc Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:38:06 +0200 Subject: [PATCH 14/35] better error handling --- .../step_operators/modal_step_operator.py | 82 +++++++++++++------ src/zenml/integrations/modal/utils.py | 43 +++++++--- 2 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index fb6988c5765..abfc83c0b8c 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -205,11 +205,18 @@ def launch( or self.config.modal_environment, ) - zenml_image = build_modal_image(image_name, stack, environment) + try: + zenml_image = build_modal_image(image_name, stack, environment) + except Exception as e: + raise RuntimeError( + "Failed to construct the Modal image from your Docker registry. " + "Action required: ensure your ZenML stack's container registry is configured with valid credentials " + "and that the base image exists and is accessible. " + f"Context: image='{image_name}'." + ) from e resource_settings = info.config.resource_settings - # Compute and validate the GPU argument via the helper. gpu_values = self._compute_modal_gpu_arg(settings, resource_settings) app = modal.App( @@ -218,26 +225,55 @@ def launch( async def run_sandbox() -> None: with modal.enable_output(): - async with app.run(): - memory_int = ( - int(mb) - if (mb := resource_settings.get_memory(ByteUnit.MB)) - else None - ) - sb = await modal.Sandbox.create.aio( - "bash", - "-c", - " ".join(entrypoint_command), - image=zenml_image, - gpu=gpu_values, - cpu=resource_settings.cpu_count, - memory=memory_int, - cloud=settings.cloud, - region=settings.region, - app=app, - timeout=settings.timeout, - ) - - await sb.wait.aio() + try: + async with app.run(): + # Compute memory lazily to preserve original semantics and avoid + # accidental integer conversion if value is missing. + memory_int = ( + int(mb) + if ( + mb := resource_settings.get_memory(ByteUnit.MB) + ) + else None + ) + + try: + sb = await modal.Sandbox.create.aio( + "bash", + "-c", + " ".join(entrypoint_command), + image=zenml_image, + gpu=gpu_values, + cpu=resource_settings.cpu_count, + memory=memory_int, + cloud=settings.cloud, + region=settings.region, + app=app, + timeout=settings.timeout, + ) + except Exception as e: + raise RuntimeError( + "Failed to create a Modal sandbox. " + "Action required: verify that the referenced Docker image exists and is accessible, " + "the requested resources are available (gpu/region/cloud), and your Modal workspace " + "permissions allow sandbox creation. " + f"Context: image='{image_name}', gpu='{gpu_values}', region='{settings.region}', cloud='{settings.cloud}'." + ) from e + + try: + await sb.wait.aio() + except Exception as e: + raise RuntimeError( + "Modal sandbox execution failed. " + "Action required: inspect the step logs in Modal, validate your entrypoint command, " + "and confirm that dependencies are available in the image/environment." + ) from e + except Exception as e: + raise RuntimeError( + "Failed to initialize Modal application context (authentication / workspace / environment). " + "Action required: make sure you're authenticated with Modal (run 'modal token new' or set " + "MODAL_TOKEN_ID and MODAL_TOKEN_SECRET), and that the configured workspace/environment exist " + "and you have access to them." + ) from e asyncio.run(run_sandbox()) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 25155f03d28..c1c7ee11b5e 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -172,19 +172,40 @@ def build_modal_image( "No Docker credentials found for the container registry." ) - registry_secret = modal.Secret.from_dict( - { - "REGISTRY_USERNAME": docker_username, - "REGISTRY_PASSWORD": docker_password, - } - ) - - modal_image = modal.Image.from_registry( - image_name, secret=registry_secret - ).pip_install("modal") + try: + registry_secret = modal.Secret.from_dict( + { + "REGISTRY_USERNAME": docker_username, + "REGISTRY_PASSWORD": docker_password, + } + ) + except Exception as e: + raise RuntimeError( + "Failed to create Modal secret for container registry credentials. " + "Action required: verify your container registry credentials in the active ZenML stack and " + "ensure your Modal account has permission to create secrets." + ) from e + + try: + modal_image = modal.Image.from_registry( + image_name, secret=registry_secret + ).pip_install("modal") + except Exception as e: + raise RuntimeError( + "Failed to construct a Modal image from the specified Docker base image. " + "Action required: ensure the image exists and is accessible from your container registry, " + "and that the provided credentials are correct." + ) from e if environment: - modal_image = modal_image.env(environment) + try: + modal_image = modal_image.env(environment) + except Exception as e: + # This is a defensive guard; env composition is local, but we still provide guidance. + raise RuntimeError( + "Failed to apply environment variables to the Modal image. " + "Action required: verify that environment variable keys and values are valid strings." + ) from e return modal_image From ec878bc4b613cc2508fe799ea54a9353db1505fc Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:39:30 +0200 Subject: [PATCH 15/35] Remove excess comments --- src/zenml/integrations/modal/utils.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index c1c7ee11b5e..2877973da37 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -172,19 +172,12 @@ def build_modal_image( "No Docker credentials found for the container registry." ) - try: - registry_secret = modal.Secret.from_dict( - { - "REGISTRY_USERNAME": docker_username, - "REGISTRY_PASSWORD": docker_password, - } - ) - except Exception as e: - raise RuntimeError( - "Failed to create Modal secret for container registry credentials. " - "Action required: verify your container registry credentials in the active ZenML stack and " - "ensure your Modal account has permission to create secrets." - ) from e + registry_secret = modal.Secret.from_dict( + { + "REGISTRY_USERNAME": docker_username, + "REGISTRY_PASSWORD": docker_password, + } + ) try: modal_image = modal.Image.from_registry( From 37fd631decb0019dc2e87ebf38f7495ac90dba7d Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:46:10 +0200 Subject: [PATCH 16/35] Add type hint for resource_settings parameter Improves type safety by explicitly typing the resource_settings parameter as ResourceSettings in _compute_modal_gpu_arg method. --- .../integrations/modal/step_operators/modal_step_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index abfc83c0b8c..3bdc3ed3d5a 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -20,7 +20,7 @@ from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration -from zenml.config.resource_settings import ByteUnit +from zenml.config.resource_settings import ByteUnit, ResourceSettings from zenml.exceptions import StackComponentInterfaceError from zenml.integrations.modal.flavors import ( ModalStepOperatorConfig, @@ -105,7 +105,7 @@ def get_docker_builds( def _compute_modal_gpu_arg( self, settings: ModalStepOperatorSettings, - resource_settings, + resource_settings: ResourceSettings, ) -> Optional[str]: """Compute and validate the Modal 'gpu' argument. From 72d70e243516bff5f33db91e6bc0b9b9dfe1b87a Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:50:47 +0200 Subject: [PATCH 17/35] Enforce timeout constraints in ModalStepOperatorSettings Adds Pydantic field constraints to enforce the timeout range (1-86400 seconds) that was previously only documented. Uses ge=1 and le=DEFAULT_TIMEOUT_SECONDS to validate input at instantiation time. --- .../flavors/modal_step_operator_flavor.py | 2 + .../step_operators/test_flavor_config.py | 38 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py index fc707738caa..13d55eea9d1 100644 --- a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py +++ b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py @@ -75,6 +75,8 @@ class ModalStepOperatorSettings(BaseSettings): ) timeout: int = Field( DEFAULT_TIMEOUT_SECONDS, + ge=1, + le=DEFAULT_TIMEOUT_SECONDS, description=f"Maximum execution time in seconds for step completion. Must be between 1 and {DEFAULT_TIMEOUT_SECONDS} seconds. " f"Examples: 3600 (1 hour), 7200 (2 hours), {DEFAULT_TIMEOUT_SECONDS} (24 hours maximum). " "Step execution will be terminated if it exceeds this timeout", diff --git a/tests/integration/integrations/modal/step_operators/test_flavor_config.py b/tests/integration/integrations/modal/step_operators/test_flavor_config.py index ac464892c3a..05957a18c4c 100644 --- a/tests/integration/integrations/modal/step_operators/test_flavor_config.py +++ b/tests/integration/integrations/modal/step_operators/test_flavor_config.py @@ -1,3 +1,6 @@ +import pytest +from pydantic import ValidationError + from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR from zenml.integrations.modal.flavors.modal_step_operator_flavor import ( DEFAULT_TIMEOUT_SECONDS, @@ -20,3 +23,38 @@ def test_modal_config_is_remote_true() -> None: def test_modal_flavor_name_constant() -> None: flavor = ModalStepOperatorFlavor() assert flavor.name == MODAL_STEP_OPERATOR_FLAVOR + + +def test_modal_settings_timeout_accepts_valid_values() -> None: + # Test minimum valid value + settings = ModalStepOperatorSettings(timeout=1) + assert settings.timeout == 1 + + # Test maximum valid value + settings = ModalStepOperatorSettings(timeout=DEFAULT_TIMEOUT_SECONDS) + assert settings.timeout == DEFAULT_TIMEOUT_SECONDS + + # Test mid-range value + settings = ModalStepOperatorSettings(timeout=3600) + assert settings.timeout == 3600 + + +def test_modal_settings_timeout_rejects_below_minimum() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=0) + + assert "greater than or equal to 1" in str(exc_info.value) + + +def test_modal_settings_timeout_rejects_above_maximum() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=DEFAULT_TIMEOUT_SECONDS + 1) + + assert "less than or equal to" in str(exc_info.value) + + +def test_modal_settings_timeout_rejects_negative() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=-1) + + assert "greater than or equal to 1" in str(exc_info.value) From b04ca61afd18c7f582a41265904b8eba1e574b12 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 17:59:54 +0200 Subject: [PATCH 18/35] Improve Modal sandbox command execution safety Changes command execution from shell-based to direct varargs invocation, eliminating shell metacharacter issues and removing the implicit bash dependency. Adds proper argument quoting via shlex.join() as a fallback for backward compatibility with older Modal client versions that don't support varargs command passing. --- .../step_operators/modal_step_operator.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 3bdc3ed3d5a..04d8eb70132 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -14,6 +14,7 @@ """Modal step operator implementation.""" import asyncio +import shlex from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast import modal @@ -238,19 +239,47 @@ async def run_sandbox() -> None: ) try: - sb = await modal.Sandbox.create.aio( - "bash", - "-c", - " ".join(entrypoint_command), - image=zenml_image, - gpu=gpu_values, - cpu=resource_settings.cpu_count, - memory=memory_int, - cloud=settings.cloud, - region=settings.region, - app=app, - timeout=settings.timeout, - ) + # Prefer direct execution of the command vector to avoid shell-quoting pitfalls + # and the implicit requirement that the image contains bash. This makes argument + # handling robust when values include spaces or shell metacharacters. + if ( + not entrypoint_command + or not entrypoint_command[0] + ): + raise ValueError( + "Empty step entrypoint command is not allowed." + ) + + try: + sb = await modal.Sandbox.create.aio( + *entrypoint_command, + image=zenml_image, + gpu=gpu_values, + cpu=resource_settings.cpu_count, + memory=memory_int, + cloud=settings.cloud, + region=settings.region, + app=app, + timeout=settings.timeout, + ) + except TypeError: + # Some Modal client versions may not accept varargs for the command. + # Fall back to a shell-quoted invocation to preserve compatibility while + # still quoting arguments safely. + quoted = shlex.join(entrypoint_command) + sb = await modal.Sandbox.create.aio( + "bash", + "-c", + quoted, + image=zenml_image, + gpu=gpu_values, + cpu=resource_settings.cpu_count, + memory=memory_int, + cloud=settings.cloud, + region=settings.region, + app=app, + timeout=settings.timeout, + ) except Exception as e: raise RuntimeError( "Failed to create a Modal sandbox. " From 581eda1b3ffab4ce4a04c9ded6159c73d9916c76 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 18:44:55 +0200 Subject: [PATCH 19/35] Enhance Modal step operator error handling for sandbox creation Improves error handling during the creation of a Modal sandbox by adding a fallback mechanism for command execution. If the direct varargs invocation fails due to a TypeError, the code now falls back to a shell-quoted invocation using shlex.join() to ensure compatibility with older Modal client versions. This change enhances robustness and maintains backward compatibility while preserving argument safety. --- .../step_operators/modal_step_operator.py | 42 +++++-------------- 1 file changed, 11 insertions(+), 31 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 04d8eb70132..e66a088c71b 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -14,7 +14,6 @@ """Modal step operator implementation.""" import asyncio -import shlex from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast import modal @@ -250,36 +249,17 @@ async def run_sandbox() -> None: "Empty step entrypoint command is not allowed." ) - try: - sb = await modal.Sandbox.create.aio( - *entrypoint_command, - image=zenml_image, - gpu=gpu_values, - cpu=resource_settings.cpu_count, - memory=memory_int, - cloud=settings.cloud, - region=settings.region, - app=app, - timeout=settings.timeout, - ) - except TypeError: - # Some Modal client versions may not accept varargs for the command. - # Fall back to a shell-quoted invocation to preserve compatibility while - # still quoting arguments safely. - quoted = shlex.join(entrypoint_command) - sb = await modal.Sandbox.create.aio( - "bash", - "-c", - quoted, - image=zenml_image, - gpu=gpu_values, - cpu=resource_settings.cpu_count, - memory=memory_int, - cloud=settings.cloud, - region=settings.region, - app=app, - timeout=settings.timeout, - ) + sb = await modal.Sandbox.create.aio( + *entrypoint_command, + image=zenml_image, + gpu=gpu_values, + cpu=resource_settings.cpu_count, + memory=memory_int, + cloud=settings.cloud, + region=settings.region, + app=app, + timeout=settings.timeout, + ) except Exception as e: raise RuntimeError( "Failed to create a Modal sandbox. " From ccef58b8c58d4e165d2c9f25c9d9e581e59a3302 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Fri, 3 Oct 2025 18:45:22 +0200 Subject: [PATCH 20/35] Remove excessive comment --- src/zenml/integrations/modal/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 2877973da37..df39b5bb63e 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -129,7 +129,6 @@ def setup_modal_client( "Run 'modal token new' to set up authentication." ) - # Configure Modal workspace/environment context if provided. if workspace: _set_env_if_present(MODAL_WORKSPACE_ENV, workspace) if environment: From b2154d68701605d48bb55e766b9a1a20adc8dd6e Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Sun, 5 Oct 2025 11:18:12 +0200 Subject: [PATCH 21/35] Refactor the get_gpu_values out to utils --- .../step_operators/modal_step_operator.py | 83 +------------------ src/zenml/integrations/modal/utils.py | 77 ++++++++++++++++- .../test_modal_step_operator_gpu.py | 39 +++------ 3 files changed, 91 insertions(+), 108 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index e66a088c71b..ca0d0a6ebc2 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -20,14 +20,14 @@ from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration -from zenml.config.resource_settings import ByteUnit, ResourceSettings -from zenml.exceptions import StackComponentInterfaceError +from zenml.config.resource_settings import ByteUnit from zenml.integrations.modal.flavors import ( ModalStepOperatorConfig, ModalStepOperatorSettings, ) from zenml.integrations.modal.utils import ( build_modal_image, + get_gpu_values, get_modal_stack_validator, setup_modal_client, ) @@ -102,83 +102,6 @@ def get_docker_builds( return builds - def _compute_modal_gpu_arg( - self, - settings: ModalStepOperatorSettings, - resource_settings: ResourceSettings, - ) -> Optional[str]: - """Compute and validate the Modal 'gpu' argument. - - Why this exists: - - Modal expects GPU resources as a string (e.g., 'T4' or 'A100:2'). - - ZenML splits GPU intent between a 'type' (settings.gpu) and a 'count' - (resource_settings.gpu_count). This helper reconciles the two and - enforces rules that prevent silent CPU fallbacks or ambiguous configs. - - Rules enforced: - - If a positive gpu_count is requested without specifying a GPU type, - raise a StackComponentInterfaceError to make the mismatch explicit. - - If a GPU type is specified but gpu_count == 0, we interpret this as - requesting 1 GPU (Modal semantics for a bare type string) and log a - warning to explain the behavior and how to request CPU-only runs. - - If neither a type nor a positive count is requested, return None for - CPU-only execution. - - Otherwise, format the string as '' for one GPU or ':' - for multiple GPUs. - """ - # Normalize GPU type: treat empty or whitespace-only strings as None to avoid - # surprising behavior when user-provided values are malformed (e.g., " "). - gpu_type_raw = settings.gpu - gpu_type = gpu_type_raw.strip() if gpu_type_raw is not None else None - if gpu_type == "": - gpu_type = None - - # Coerce and validate gpu_count to ensure it's a non-negative integer if provided. - gpu_count = resource_settings.gpu_count - if gpu_count is not None: - try: - gpu_count = int(gpu_count) - except (TypeError, ValueError): - raise StackComponentInterfaceError( - f"Invalid GPU count '{gpu_count}'. Must be a non-negative integer." - ) - if gpu_count < 0: - raise StackComponentInterfaceError( - f"Invalid GPU count '{gpu_count}'. Must be >= 0." - ) - - # Scenario 1: Count requested but type missing -> invalid configuration. - if gpu_type is None: - if gpu_count is not None and gpu_count > 0: - raise StackComponentInterfaceError( - "GPU resources requested (gpu_count > 0) but no GPU type was specified " - "in Modal settings. Please set a GPU type (e.g., 'T4', 'A100') via " - "ModalStepOperatorSettings.gpu or @step(settings={'modal': {'gpu': ''}}), " - "or set gpu_count=0 to run on CPU." - ) - # CPU-only scenarios (no type, count is None or 0). - return None - - # Scenario 2: Type set but count == 0 -> warn and default to 1 GPU. - if gpu_count == 0: - logger.warning( - "Modal GPU type '%s' is configured but ResourceSettings.gpu_count is 0. " - "Defaulting to 1 GPU. To run on CPU only, remove the GPU type or ensure " - "gpu_count=0 with no GPU type configured.", - gpu_type, - ) - return gpu_type # Implicitly 1 GPU for Modal - - # Valid configurations - if gpu_count is None: - return gpu_type # Implicitly 1 GPU for Modal - - if gpu_count > 0: - return f"{gpu_type}:{gpu_count}" - - # Defensive fallback shouldn't be reachable due to validation above; return CPU-only None if hit. - return None - def launch( self, info: "StepRunInfo", @@ -217,7 +140,7 @@ def launch( resource_settings = info.config.resource_settings - gpu_values = self._compute_modal_gpu_arg(settings, resource_settings) + gpu_values = get_gpu_values(settings, resource_settings) app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index df39b5bb63e..3d0b5bea85a 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -14,16 +14,21 @@ """Shared utilities for Modal integration components.""" import os -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import modal +from zenml.config.resource_settings import ResourceSettings from zenml.enums import StackComponentType +from zenml.exceptions import StackComponentInterfaceError from zenml.logger import get_logger from zenml.stack import Stack, StackValidator logger = get_logger(__name__) +if TYPE_CHECKING: + from zenml.integrations.modal.flavors import ModalStepOperatorSettings + MODAL_TOKEN_ID_ENV = "MODAL_TOKEN_ID" MODAL_TOKEN_SECRET_ENV = "MODAL_TOKEN_SECRET" MODAL_WORKSPACE_ENV = "MODAL_WORKSPACE" @@ -202,6 +207,76 @@ def build_modal_image( return modal_image +def get_gpu_values( + settings: "ModalStepOperatorSettings", + resource_settings: "ResourceSettings", +) -> Optional[str]: + """Compute and validate the Modal ``gpu`` argument string. + + Modal expects GPU resources as either ``None`` (CPU only), a GPU type string + like ``"A100"`` (implicitly a single GPU), or ``"A100:2"`` when multiple + GPUs of the same type are requested. Within ZenML, the GPU type is captured + in :class:`ModalStepOperatorSettings` while the count lives in + :class:`~zenml.config.resource_settings.ResourceSettings`. This helper + reconciles both sources so other Modal components can reuse the same + validation rules. + + Args: + settings: The Modal step operator settings describing the GPU type. + resource_settings: Resource constraints for the step, providing the GPU count. + + Returns: + A Modal-compatible GPU specification string or ``None`` when running on CPU. + + Raises: + StackComponentInterfaceError: If the configuration is inconsistent or invalid. + """ + gpu_type_raw = settings.gpu + gpu_type = gpu_type_raw.strip() if gpu_type_raw is not None else None + if gpu_type == "": + gpu_type = None + + gpu_count = resource_settings.gpu_count + if gpu_count is not None: + try: + gpu_count = int(gpu_count) + except (TypeError, ValueError): + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be a non-negative integer." + ) + if gpu_count < 0: + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be >= 0." + ) + + if gpu_type is None: + if gpu_count is not None and gpu_count > 0: + raise StackComponentInterfaceError( + "GPU resources requested (gpu_count > 0) but no GPU type was specified " + "in Modal settings. Please set a GPU type (e.g., 'T4', 'A100') via " + "ModalStepOperatorSettings.gpu or @step(settings={'modal': {'gpu': ''}}), " + "or set gpu_count=0 to run on CPU." + ) + return None + + if gpu_count == 0: + logger.warning( + "Modal GPU type '%s' is configured but ResourceSettings.gpu_count is 0. " + "Defaulting to 1 GPU. To run on CPU only, remove the GPU type or ensure " + "gpu_count=0 with no GPU type configured.", + gpu_type, + ) + return gpu_type + + if gpu_count is None: + return gpu_type + + if gpu_count > 0: + return f"{gpu_type}:{gpu_count}" + + return None + + def get_modal_stack_validator() -> StackValidator: """Get a stack validator for Modal components. diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py index 07f575b4c86..c2e0b4491f7 100644 --- a/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py +++ b/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py @@ -6,9 +6,7 @@ from zenml.exceptions import StackComponentInterfaceError from zenml.integrations.modal.flavors import ModalStepOperatorSettings -from zenml.integrations.modal.step_operators.modal_step_operator import ( - ModalStepOperator, -) +from zenml.integrations.modal.utils import get_gpu_values class ResourceSettingsStub: @@ -22,24 +20,17 @@ def __init__(self, gpu_count): self.gpu_count = gpu_count -def _make_operator() -> ModalStepOperator: - # Bypass BaseStepOperator initialization since we only need the helper. - return ModalStepOperator.__new__(ModalStepOperator) - - def test_gpu_arg_none_when_no_type_and_no_count() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu=None) rs = ResourceSettingsStub(gpu_count=None) - assert op._compute_modal_gpu_arg(settings, rs) is None + assert get_gpu_values(settings, rs) is None def test_gpu_arg_raises_when_count_without_type() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu=None) rs = ResourceSettingsStub(gpu_count=1) with pytest.raises(StackComponentInterfaceError) as e: - op._compute_modal_gpu_arg(settings, rs) + get_gpu_values(settings, rs) assert ( "GPU resources requested (gpu_count > 0) but no GPU type was specified" in str(e.value) @@ -47,67 +38,61 @@ def test_gpu_arg_raises_when_count_without_type() -> None: def test_gpu_arg_type_with_no_count_returns_type() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu="A100") rs = ResourceSettingsStub(gpu_count=None) - assert op._compute_modal_gpu_arg(settings, rs) == "A100" + assert get_gpu_values(settings, rs) == "A100" def test_gpu_arg_type_with_count_returns_type_colon_count() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu="A100") rs_two = ResourceSettingsStub(gpu_count=2) - assert op._compute_modal_gpu_arg(settings, rs_two) == "A100:2" + assert get_gpu_values(settings, rs_two) == "A100:2" rs_one = ResourceSettingsStub(gpu_count=1) - assert op._compute_modal_gpu_arg(settings, rs_one) == "A100:1" + assert get_gpu_values(settings, rs_one) == "A100:1" def test_gpu_arg_type_with_zero_count_warns_and_defaults_to_single_gpu( caplog, ) -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu="A100") rs = ResourceSettingsStub(gpu_count=0) with caplog.at_level("WARNING"): - result = op._compute_modal_gpu_arg(settings, rs) + result = get_gpu_values(settings, rs) assert result == "A100" assert "Defaulting to 1 GPU" in caplog.text def test_gpu_arg_invalid_negative_count_raises() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu="T4") rs = ResourceSettingsStub(gpu_count=-1) with pytest.raises(StackComponentInterfaceError) as e: - op._compute_modal_gpu_arg(settings, rs) + get_gpu_values(settings, rs) assert "Invalid GPU count" in str(e.value) def test_gpu_arg_non_integer_count_raises() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu="T4") rs = ResourceSettingsStub(gpu_count="two") with pytest.raises(StackComponentInterfaceError) as e: - op._compute_modal_gpu_arg(settings, rs) + get_gpu_values(settings, rs) assert "Invalid GPU count" in str(e.value) def test_gpu_arg_whitespace_type_treated_as_none_behavior() -> None: - op = _make_operator() settings = ModalStepOperatorSettings(gpu=" ") # With positive GPU count this should raise since type is treated as None. rs_positive = ResourceSettingsStub(gpu_count=2) with pytest.raises(StackComponentInterfaceError): - op._compute_modal_gpu_arg(settings, rs_positive) + get_gpu_values(settings, rs_positive) # With zero or None count, this should be CPU-only (None). rs_zero = ResourceSettingsStub(gpu_count=0) - assert op._compute_modal_gpu_arg(settings, rs_zero) is None + assert get_gpu_values(settings, rs_zero) is None rs_none = ResourceSettingsStub(gpu_count=None) - assert op._compute_modal_gpu_arg(settings, rs_none) is None + assert get_gpu_values(settings, rs_none) is None From e4d4af0bb750deec2210202878f53090dfdf3317 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 10:07:27 +0200 Subject: [PATCH 22/35] Fix darglint docstring errors in Modal integration Add missing parameter documentation to helper functions and exception documentation to the launch method to satisfy darglint requirements. --- .../modal/step_operators/modal_step_operator.py | 4 ++++ src/zenml/integrations/modal/utils.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index ca0d0a6ebc2..fe6a120ff8c 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -114,6 +114,10 @@ def launch( info: The step run information. entrypoint_command: The entrypoint command for the step. environment: The environment variables for the step. + + Raises: + RuntimeError: If Modal image construction fails, sandbox creation fails, + sandbox execution fails, or Modal application initialization fails. """ settings = cast(ModalStepOperatorSettings, self.get_settings(info)) image_name = info.get_image(key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 3d0b5bea85a..1f8789f2d03 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -42,6 +42,11 @@ def _validate_token_prefix( """Warn if a credential doesn't match Modal's expected prefix. This helps catch misconfigurations early without logging secret content. + + Args: + value: The credential value to validate. + expected_prefix: The prefix that the credential should start with. + label: Human-readable label for the credential (used in warning messages). """ if not value.startswith(expected_prefix): logger.warning( @@ -52,6 +57,10 @@ def _validate_token_prefix( def _set_env_if_present(var_name: str, value: Optional[str]) -> bool: """Set an environment variable only if a non-empty value is provided. + Args: + var_name: The name of the environment variable to set. + value: The value to set for the environment variable, or None. + Returns: True if the environment variable was set, False otherwise. """ From 7bcf68cc2e28a65ddb1ba7e3708207b7b15cee6c Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 13:35:22 +0200 Subject: [PATCH 23/35] Small changes --- .../modal/step_operators/modal_step_operator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index fe6a120ff8c..22f4b24c973 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -136,7 +136,7 @@ def launch( zenml_image = build_modal_image(image_name, stack, environment) except Exception as e: raise RuntimeError( - "Failed to construct the Modal image from your Docker registry. " + "Failed to build the Modal image from your Docker registry. " "Action required: ensure your ZenML stack's container registry is configured with valid credentials " "and that the base image exists and is accessible. " f"Context: image='{image_name}'." @@ -154,8 +154,6 @@ async def run_sandbox() -> None: with modal.enable_output(): try: async with app.run(): - # Compute memory lazily to preserve original semantics and avoid - # accidental integer conversion if value is missing. memory_int = ( int(mb) if ( @@ -165,9 +163,6 @@ async def run_sandbox() -> None: ) try: - # Prefer direct execution of the command vector to avoid shell-quoting pitfalls - # and the implicit requirement that the image contains bash. This makes argument - # handling robust when values include spaces or shell metacharacters. if ( not entrypoint_command or not entrypoint_command[0] From f4c753aaacc0599fb624aedd7c92b63075deb8de Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 13:38:00 +0200 Subject: [PATCH 24/35] mypy fix --- src/zenml/cli/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/cli/cli.py b/src/zenml/cli/cli.py index baa29bafbe5..01346581782 100644 --- a/src/zenml/cli/cli.py +++ b/src/zenml/cli/cli.py @@ -43,7 +43,7 @@ def __init__( commands: Optional[ Union[Dict[str, click.Command], Sequence[click.Command]] ] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: """Initialize the Tag group. From ae065d1daadb6f608e82edfd2ab3e1b037e1f157 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 13:46:29 +0200 Subject: [PATCH 25/35] Tests go in the right folders --- .../integrations/modal/flavors/__init__.py | 0 .../test_flavor_config.py | 0 ...tor_gpu.py => test_modal_step_operator.py} | 0 .../test_utils_build_and_validator.py | 56 ------------------- .../test_utils_env.py => test_modal_utils.py} | 48 ++++++++++++++++ 5 files changed, 48 insertions(+), 56 deletions(-) create mode 100644 tests/integration/integrations/modal/flavors/__init__.py rename tests/integration/integrations/modal/{step_operators => flavors}/test_flavor_config.py (100%) rename tests/integration/integrations/modal/step_operators/{test_modal_step_operator_gpu.py => test_modal_step_operator.py} (100%) delete mode 100644 tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py rename tests/integration/integrations/modal/{step_operators/test_utils_env.py => test_modal_utils.py} (72%) diff --git a/tests/integration/integrations/modal/flavors/__init__.py b/tests/integration/integrations/modal/flavors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/integrations/modal/step_operators/test_flavor_config.py b/tests/integration/integrations/modal/flavors/test_flavor_config.py similarity index 100% rename from tests/integration/integrations/modal/step_operators/test_flavor_config.py rename to tests/integration/integrations/modal/flavors/test_flavor_config.py diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py similarity index 100% rename from tests/integration/integrations/modal/step_operators/test_modal_step_operator_gpu.py rename to tests/integration/integrations/modal/step_operators/test_modal_step_operator.py diff --git a/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py b/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py deleted file mode 100644 index 5735212cf8e..00000000000 --- a/tests/integration/integrations/modal/step_operators/test_utils_build_and_validator.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytest - -# Skip this entire module if the optional dependency isn't available, because -# the module under test imports `modal` at import time. -pytest.importorskip("modal") - -from zenml.integrations.modal.utils import ( - build_modal_image, - get_modal_stack_validator, -) -from zenml.stack.stack_validator import StackValidator - - -class StackStubNoRegistry: - """Stack stub with no container registry to trigger early validation.""" - - container_registry = None - - -class ContainerRegistryStubNoCreds: - """Container registry stub with missing credentials.""" - - def __init__(self): - self.credentials = None - - -class StackStubWithRegistryNoCreds: - """Stack stub with a container registry but no credentials.""" - - def __init__(self): - self.container_registry = ContainerRegistryStubNoCreds() - - -def test_build_modal_image_raises_when_no_registry() -> None: - with pytest.raises(RuntimeError) as e: - build_modal_image( - "repo/image:tag", StackStubNoRegistry(), environment=None - ) - assert "No Container registry found in the stack" in str(e.value) - - -def test_build_modal_image_raises_when_no_credentials() -> None: - with pytest.raises(RuntimeError) as e: - build_modal_image( - "repo/image:tag", StackStubWithRegistryNoCreds(), environment=None - ) - assert "No Docker credentials found for the container registry" in str( - e.value - ) - - -def test_get_modal_stack_validator_returns_stackvalidator_instance() -> None: - validator = get_modal_stack_validator() - assert isinstance(validator, StackValidator) - # Ensure the validator exposes a validate-like callable (public contract) - assert hasattr(validator, "validate") and callable(validator.validate) diff --git a/tests/integration/integrations/modal/step_operators/test_utils_env.py b/tests/integration/integrations/modal/test_modal_utils.py similarity index 72% rename from tests/integration/integrations/modal/step_operators/test_utils_env.py rename to tests/integration/integrations/modal/test_modal_utils.py index 29acaf666f0..9b7d3b95563 100644 --- a/tests/integration/integrations/modal/step_operators/test_utils_env.py +++ b/tests/integration/integrations/modal/test_modal_utils.py @@ -14,8 +14,56 @@ MODAL_WORKSPACE_ENV, _set_env_if_present, _validate_token_prefix, + build_modal_image, + get_modal_stack_validator, setup_modal_client, ) +from zenml.stack.stack_validator import StackValidator + + +class StackStubNoRegistry: + """Stack stub with no container registry to trigger early validation.""" + + container_registry = None + + +class ContainerRegistryStubNoCreds: + """Container registry stub with missing credentials.""" + + def __init__(self): + self.credentials = None + + +class StackStubWithRegistryNoCreds: + """Stack stub with a container registry but no credentials.""" + + def __init__(self): + self.container_registry = ContainerRegistryStubNoCreds() + + +def test_build_modal_image_raises_when_no_registry() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubNoRegistry(), environment=None + ) + assert "No Container registry found in the stack" in str(e.value) + + +def test_build_modal_image_raises_when_no_credentials() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubWithRegistryNoCreds(), environment=None + ) + assert "No Docker credentials found for the container registry" in str( + e.value + ) + + +def test_get_modal_stack_validator_returns_stackvalidator_instance() -> None: + validator = get_modal_stack_validator() + assert isinstance(validator, StackValidator) + # Ensure the validator exposes a validate-like callable (public contract) + assert hasattr(validator, "validate") and callable(validator.validate) def test_set_env_if_present_sets_and_returns_true(monkeypatch) -> None: From 684cd2c5194d110289f58fd3ca6a26b67d867faf Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 13:47:31 +0200 Subject: [PATCH 26/35] Add licenses --- .../integrations/modal/flavors/__init__.py | 13 +++++++++++++ .../modal/flavors/test_flavor_config.py | 14 ++++++++++++++ .../integrations/modal/step_operators/__init__.py | 2 +- .../step_operators/test_modal_step_operator.py | 14 ++++++++++++++ .../integrations/modal/test_modal_utils.py | 14 ++++++++++++++ 5 files changed, 56 insertions(+), 1 deletion(-) diff --git a/tests/integration/integrations/modal/flavors/__init__.py b/tests/integration/integrations/modal/flavors/__init__.py index e69de29bb2d..32ba63d3703 100644 --- a/tests/integration/integrations/modal/flavors/__init__.py +++ b/tests/integration/integrations/modal/flavors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/tests/integration/integrations/modal/flavors/test_flavor_config.py b/tests/integration/integrations/modal/flavors/test_flavor_config.py index 05957a18c4c..dd1f2566ff1 100644 --- a/tests/integration/integrations/modal/flavors/test_flavor_config.py +++ b/tests/integration/integrations/modal/flavors/test_flavor_config.py @@ -1,3 +1,17 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + import pytest from pydantic import ValidationError diff --git a/tests/integration/integrations/modal/step_operators/__init__.py b/tests/integration/integrations/modal/step_operators/__init__.py index cd90a82cfc2..32ba63d3703 100644 --- a/tests/integration/integrations/modal/step_operators/__init__.py +++ b/tests/integration/integrations/modal/step_operators/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py index c2e0b4491f7..91142b8b3ab 100644 --- a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py +++ b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py @@ -1,3 +1,17 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + import pytest # Skip this entire module if the optional dependency isn't available, because diff --git a/tests/integration/integrations/modal/test_modal_utils.py b/tests/integration/integrations/modal/test_modal_utils.py index 9b7d3b95563..e94fc75b735 100644 --- a/tests/integration/integrations/modal/test_modal_utils.py +++ b/tests/integration/integrations/modal/test_modal_utils.py @@ -1,3 +1,17 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + import os import pytest From fb4dcb770906031494f8d2b69bddf59e6e0d19fb Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 6 Oct 2025 13:49:56 +0200 Subject: [PATCH 27/35] Remove dumb tests --- .../modal/flavors/test_flavor_config.py | 18 ------------------ .../integrations/modal/test_modal_utils.py | 9 --------- 2 files changed, 27 deletions(-) diff --git a/tests/integration/integrations/modal/flavors/test_flavor_config.py b/tests/integration/integrations/modal/flavors/test_flavor_config.py index dd1f2566ff1..ab3dd2d57ef 100644 --- a/tests/integration/integrations/modal/flavors/test_flavor_config.py +++ b/tests/integration/integrations/modal/flavors/test_flavor_config.py @@ -15,30 +15,12 @@ import pytest from pydantic import ValidationError -from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR from zenml.integrations.modal.flavors.modal_step_operator_flavor import ( DEFAULT_TIMEOUT_SECONDS, - ModalStepOperatorConfig, - ModalStepOperatorFlavor, ModalStepOperatorSettings, ) -def test_modal_settings_default_timeout() -> None: - settings = ModalStepOperatorSettings() - assert settings.timeout == DEFAULT_TIMEOUT_SECONDS - - -def test_modal_config_is_remote_true() -> None: - cfg = ModalStepOperatorConfig() - assert cfg.is_remote is True - - -def test_modal_flavor_name_constant() -> None: - flavor = ModalStepOperatorFlavor() - assert flavor.name == MODAL_STEP_OPERATOR_FLAVOR - - def test_modal_settings_timeout_accepts_valid_values() -> None: # Test minimum valid value settings = ModalStepOperatorSettings(timeout=1) diff --git a/tests/integration/integrations/modal/test_modal_utils.py b/tests/integration/integrations/modal/test_modal_utils.py index e94fc75b735..5cf30b77f75 100644 --- a/tests/integration/integrations/modal/test_modal_utils.py +++ b/tests/integration/integrations/modal/test_modal_utils.py @@ -29,10 +29,8 @@ _set_env_if_present, _validate_token_prefix, build_modal_image, - get_modal_stack_validator, setup_modal_client, ) -from zenml.stack.stack_validator import StackValidator class StackStubNoRegistry: @@ -73,13 +71,6 @@ def test_build_modal_image_raises_when_no_credentials() -> None: ) -def test_get_modal_stack_validator_returns_stackvalidator_instance() -> None: - validator = get_modal_stack_validator() - assert isinstance(validator, StackValidator) - # Ensure the validator exposes a validate-like callable (public contract) - assert hasattr(validator, "validate") and callable(validator.validate) - - def test_set_env_if_present_sets_and_returns_true(monkeypatch) -> None: monkeypatch.delenv("FOO_TEST", raising=False) assert _set_env_if_present("FOO_TEST", "bar") is True From 86ec76c252a5783d3fd38cbc1c36d311f116b1ed Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:09:48 +0200 Subject: [PATCH 28/35] Revert Optional setting for environment --- .../integrations/modal/step_operators/modal_step_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 22f4b24c973..4a630330394 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -106,7 +106,7 @@ def launch( self, info: "StepRunInfo", entrypoint_command: List[str], - environment: Optional[Dict[str, str]], + environment: Dict[str, str], ) -> None: """Launch a step run on Modal. From c999e47759d563bd20c54028762c14638d6d92bf Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:13:58 +0200 Subject: [PATCH 29/35] Fix variable naming and error message in Modal step operator Rename zenml_image to modal_image to accurately reflect that the function returns a Modal-specific image object, not a ZenML construct. Update error message to describe actual failure scenarios: Modal's inability to access the container registry or extend the Docker image, rather than suggesting the image might not exist (which is incorrect since ZenML just built and pushed it successfully). --- .../modal/step_operators/modal_step_operator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 4a630330394..08561d9d78c 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -133,12 +133,12 @@ def launch( ) try: - zenml_image = build_modal_image(image_name, stack, environment) + modal_image = build_modal_image(image_name, stack, environment) except Exception as e: raise RuntimeError( - "Failed to build the Modal image from your Docker registry. " - "Action required: ensure your ZenML stack's container registry is configured with valid credentials " - "and that the base image exists and is accessible. " + "Failed to construct Modal execution environment from your Docker image. " + "Action required: verify that Modal can access your container registry (check network connectivity " + "and registry permissions), and that the Docker image can be pulled and extended with additional dependencies. " f"Context: image='{image_name}'." ) from e @@ -173,7 +173,7 @@ async def run_sandbox() -> None: sb = await modal.Sandbox.create.aio( *entrypoint_command, - image=zenml_image, + image=modal_image, gpu=gpu_values, cpu=resource_settings.cpu_count, memory=memory_int, From 31d1a63b024cbc75796088a6d3a450b59b31e82f Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:17:43 +0200 Subject: [PATCH 30/35] Move memory calculation outside Modal runtime context Move the memory_int calculation from inside the async with app.run() context to outside, grouping it with other resource preparations like gpu_values. This is a pure data transformation that doesn't require Modal's runtime context, improving code organization and reducing unnecessary nesting. --- .../modal/step_operators/modal_step_operator.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 08561d9d78c..0ff89083660 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -146,6 +146,12 @@ def launch( gpu_values = get_gpu_values(settings, resource_settings) + memory_int = ( + int(mb) + if (mb := resource_settings.get_memory(ByteUnit.MB)) + else None + ) + app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" ) @@ -154,14 +160,6 @@ async def run_sandbox() -> None: with modal.enable_output(): try: async with app.run(): - memory_int = ( - int(mb) - if ( - mb := resource_settings.get_memory(ByteUnit.MB) - ) - else None - ) - try: if ( not entrypoint_command From 0599ddfa0f532512b7f565ee7218ad9b9e361da6 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:18:44 +0200 Subject: [PATCH 31/35] Remove unneeded guardrail --- .../modal/step_operators/modal_step_operator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 0ff89083660..f40ad410e5b 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -161,14 +161,6 @@ async def run_sandbox() -> None: try: async with app.run(): try: - if ( - not entrypoint_command - or not entrypoint_command[0] - ): - raise ValueError( - "Empty step entrypoint command is not allowed." - ) - sb = await modal.Sandbox.create.aio( *entrypoint_command, image=modal_image, From c54b2b40aa84e2f219784ba0308ca3f14cbe6864 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:20:25 +0200 Subject: [PATCH 32/35] Update comments --- .../modal/step_operators/modal_step_operator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index f40ad410e5b..1454a3c3d84 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -175,7 +175,7 @@ async def run_sandbox() -> None: except Exception as e: raise RuntimeError( "Failed to create a Modal sandbox. " - "Action required: verify that the referenced Docker image exists and is accessible, " + "Action required: verify that the referenced Docker image is accessible, " "the requested resources are available (gpu/region/cloud), and your Modal workspace " "permissions allow sandbox creation. " f"Context: image='{image_name}', gpu='{gpu_values}', region='{settings.region}', cloud='{settings.cloud}'." @@ -186,8 +186,8 @@ async def run_sandbox() -> None: except Exception as e: raise RuntimeError( "Modal sandbox execution failed. " - "Action required: inspect the step logs in Modal, validate your entrypoint command, " - "and confirm that dependencies are available in the image/environment." + "Action required: inspect the step logs in Modal, and confirm that " + "dependencies are available in the image/environment." ) from e except Exception as e: raise RuntimeError( From 24bfd63a459a36cbe8ea1fb095b85a72de833e26 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 10:23:55 +0200 Subject: [PATCH 33/35] Remove extra modal pip install --- src/zenml/integrations/modal/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py index 1f8789f2d03..fbbcf44240a 100644 --- a/src/zenml/integrations/modal/utils.py +++ b/src/zenml/integrations/modal/utils.py @@ -195,7 +195,7 @@ def build_modal_image( try: modal_image = modal.Image.from_registry( image_name, secret=registry_secret - ).pip_install("modal") + ) except Exception as e: raise RuntimeError( "Failed to construct a Modal image from the specified Docker base image. " From 1147e99dfbe4108d7ce17b61b90758974c428dca Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 11:36:15 +0200 Subject: [PATCH 34/35] Adapt CLI tests for Click 8.2 compatibility --- .../integration/functional/cli/test_model.py | 28 +++++++++---------- tests/integration/functional/cli/test_tag.py | 18 ++++++------ tests/integration/functional/cli/utils.py | 13 +++++++++ 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/tests/integration/functional/cli/test_model.py b/tests/integration/functional/cli/test_model.py index 48c67b298ad..27bd55cfb92 100644 --- a/tests/integration/functional/cli/test_model.py +++ b/tests/integration/functional/cli/test_model.py @@ -16,16 +16,16 @@ from uuid import uuid4 import pytest -from click.testing import CliRunner from tests.integration.functional.cli.conftest import NAME, PREFIX +from tests.integration.functional.cli.utils import cli_runner from zenml.cli.cli import cli from zenml.client import Client def test_model_list(clean_client_with_models: "Client"): """Test that zenml model list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -33,7 +33,7 @@ def test_model_list(clean_client_with_models: "Client"): def test_model_create_short_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with short names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -82,7 +82,7 @@ def test_model_create_short_names(clean_client_with_models: "Client"): def test_model_create_full_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with full names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -131,7 +131,7 @@ def test_model_create_full_names(clean_client_with_models: "Client"): def test_model_create_only_required(clean_client_with_models: "Client"): """Test that zenml model create does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -155,7 +155,7 @@ def test_model_create_only_required(clean_client_with_models: "Client"): def test_model_update(clean_client_with_models: "Client"): """Test that zenml model update does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["model"].commands["update"] result = runner.invoke( update_command, @@ -185,7 +185,7 @@ def test_model_create_without_required_fails( clean_client_with_models: "Client", ): """Test that zenml model create fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] result = runner.invoke( create_command, @@ -195,7 +195,7 @@ def test_model_create_without_required_fails( def test_model_delete_found(clean_client_with_models: "Client"): """Test that zenml model delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) create_command = cli.commands["model"].commands["register"] runner.invoke( @@ -212,7 +212,7 @@ def test_model_delete_found(clean_client_with_models: "Client"): def test_model_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) delete_command = cli.commands["model"].commands["delete"] result = runner.invoke( @@ -224,7 +224,7 @@ def test_model_delete_not_found(clean_client_with_models: "Client"): def test_model_version_list(clean_client_with_models: "Client"): """Test that zenml model version list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["version"].commands["list"] result = runner.invoke(list_command, args=[f"--model={NAME}"]) assert result.exit_code == 0, result.stderr @@ -232,7 +232,7 @@ def test_model_version_list(clean_client_with_models: "Client"): def test_model_version_delete_found(clean_client_with_models: "Client"): """Test that zenml model version delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) model = clean_client_with_models.create_model( @@ -254,7 +254,7 @@ def test_model_version_delete_found(clean_client_with_models: "Client"): def test_model_version_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model version delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) clean_client_with_models.create_model( @@ -278,7 +278,7 @@ def test_model_version_links_list( command: str, clean_client_with_models: "Client" ): """Test that zenml model version artifacts list fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands[command] result = runner.invoke( list_command, @@ -289,7 +289,7 @@ def test_model_version_links_list( def test_model_version_update(clean_client_with_models: "Client"): """Test that zenml model version stage update pass.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = ( cli.commands["model"].commands["version"].commands["update"] ) diff --git a/tests/integration/functional/cli/test_tag.py b/tests/integration/functional/cli/test_tag.py index 775e6f1b013..0d20767f962 100644 --- a/tests/integration/functional/cli/test_tag.py +++ b/tests/integration/functional/cli/test_tag.py @@ -14,9 +14,9 @@ """Test zenml tag CLI commands.""" import pytest -from click.testing import CliRunner from tests.integration.functional.cli.utils import ( + cli_runner, random_resource_name, tags_killer, ) @@ -29,7 +29,7 @@ def test_tag_list(): """Test that zenml tag list does not fail.""" with tags_killer(): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["tag"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -38,7 +38,7 @@ def test_tag_list(): def test_tag_create_short_names(): """Test that zenml tag create does not fail with short names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -55,7 +55,7 @@ def test_tag_create_short_names(): def test_tag_create_full_names(): """Test that zenml tag create does not fail with full names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -72,7 +72,7 @@ def test_tag_create_full_names(): def test_tag_create_only_required(): """Test that zenml tag create does not fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -93,7 +93,7 @@ def test_tag_update(): """Test that zenml tag update does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["tag"].commands["update"] color_to_set = "yellow" if tag.color.value != "yellow" else "grey" result = runner.invoke( @@ -138,7 +138,7 @@ def test_tag_update(): def test_tag_create_without_required_fails(): """Test that zenml tag create fails.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] result = runner.invoke( create_command, @@ -150,7 +150,7 @@ def test_tag_delete_found(): """Test that zenml tag delete does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, @@ -165,7 +165,7 @@ def test_tag_delete_found(): def test_tag_delete_not_found(): """Test that zenml tag delete fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, diff --git a/tests/integration/functional/cli/utils.py b/tests/integration/functional/cli/utils.py index 7a8331a919c..d41e0db6f21 100644 --- a/tests/integration/functional/cli/utils.py +++ b/tests/integration/functional/cli/utils.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import inspect from contextlib import contextmanager from typing import Generator, Optional, Tuple +from click.testing import CliRunner + from tests.harness.harness import TestHarness from zenml.cli import cli from zenml.cli.utils import ( @@ -35,6 +38,16 @@ ] +def cli_runner(**kwargs) -> CliRunner: + """Return a Click runner that stays compatible across Click releases.""" + runner_kwargs = dict(kwargs) + if "mix_stderr" not in runner_kwargs: + params = inspect.signature(CliRunner.__init__).parameters + if "mix_stderr" in params: + runner_kwargs["mix_stderr"] = False + return CliRunner(**runner_kwargs) + + # ----- # # USERS # # ----- # From ad071ed075b45d171a96ca50879c90ec1e4b88ef Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 7 Oct 2025 11:36:15 +0200 Subject: [PATCH 35/35] Adapt CLI tests for Click 8.2 compatibility --- .../integration/functional/cli/test_model.py | 28 +++++++++---------- tests/integration/functional/cli/test_tag.py | 18 ++++++------ tests/integration/functional/cli/utils.py | 12 ++++++++ 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/tests/integration/functional/cli/test_model.py b/tests/integration/functional/cli/test_model.py index 48c67b298ad..27bd55cfb92 100644 --- a/tests/integration/functional/cli/test_model.py +++ b/tests/integration/functional/cli/test_model.py @@ -16,16 +16,16 @@ from uuid import uuid4 import pytest -from click.testing import CliRunner from tests.integration.functional.cli.conftest import NAME, PREFIX +from tests.integration.functional.cli.utils import cli_runner from zenml.cli.cli import cli from zenml.client import Client def test_model_list(clean_client_with_models: "Client"): """Test that zenml model list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -33,7 +33,7 @@ def test_model_list(clean_client_with_models: "Client"): def test_model_create_short_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with short names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -82,7 +82,7 @@ def test_model_create_short_names(clean_client_with_models: "Client"): def test_model_create_full_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with full names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -131,7 +131,7 @@ def test_model_create_full_names(clean_client_with_models: "Client"): def test_model_create_only_required(clean_client_with_models: "Client"): """Test that zenml model create does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -155,7 +155,7 @@ def test_model_create_only_required(clean_client_with_models: "Client"): def test_model_update(clean_client_with_models: "Client"): """Test that zenml model update does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["model"].commands["update"] result = runner.invoke( update_command, @@ -185,7 +185,7 @@ def test_model_create_without_required_fails( clean_client_with_models: "Client", ): """Test that zenml model create fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] result = runner.invoke( create_command, @@ -195,7 +195,7 @@ def test_model_create_without_required_fails( def test_model_delete_found(clean_client_with_models: "Client"): """Test that zenml model delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) create_command = cli.commands["model"].commands["register"] runner.invoke( @@ -212,7 +212,7 @@ def test_model_delete_found(clean_client_with_models: "Client"): def test_model_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) delete_command = cli.commands["model"].commands["delete"] result = runner.invoke( @@ -224,7 +224,7 @@ def test_model_delete_not_found(clean_client_with_models: "Client"): def test_model_version_list(clean_client_with_models: "Client"): """Test that zenml model version list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["version"].commands["list"] result = runner.invoke(list_command, args=[f"--model={NAME}"]) assert result.exit_code == 0, result.stderr @@ -232,7 +232,7 @@ def test_model_version_list(clean_client_with_models: "Client"): def test_model_version_delete_found(clean_client_with_models: "Client"): """Test that zenml model version delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) model = clean_client_with_models.create_model( @@ -254,7 +254,7 @@ def test_model_version_delete_found(clean_client_with_models: "Client"): def test_model_version_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model version delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) clean_client_with_models.create_model( @@ -278,7 +278,7 @@ def test_model_version_links_list( command: str, clean_client_with_models: "Client" ): """Test that zenml model version artifacts list fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands[command] result = runner.invoke( list_command, @@ -289,7 +289,7 @@ def test_model_version_links_list( def test_model_version_update(clean_client_with_models: "Client"): """Test that zenml model version stage update pass.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = ( cli.commands["model"].commands["version"].commands["update"] ) diff --git a/tests/integration/functional/cli/test_tag.py b/tests/integration/functional/cli/test_tag.py index 775e6f1b013..0d20767f962 100644 --- a/tests/integration/functional/cli/test_tag.py +++ b/tests/integration/functional/cli/test_tag.py @@ -14,9 +14,9 @@ """Test zenml tag CLI commands.""" import pytest -from click.testing import CliRunner from tests.integration.functional.cli.utils import ( + cli_runner, random_resource_name, tags_killer, ) @@ -29,7 +29,7 @@ def test_tag_list(): """Test that zenml tag list does not fail.""" with tags_killer(): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["tag"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -38,7 +38,7 @@ def test_tag_list(): def test_tag_create_short_names(): """Test that zenml tag create does not fail with short names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -55,7 +55,7 @@ def test_tag_create_short_names(): def test_tag_create_full_names(): """Test that zenml tag create does not fail with full names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -72,7 +72,7 @@ def test_tag_create_full_names(): def test_tag_create_only_required(): """Test that zenml tag create does not fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -93,7 +93,7 @@ def test_tag_update(): """Test that zenml tag update does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["tag"].commands["update"] color_to_set = "yellow" if tag.color.value != "yellow" else "grey" result = runner.invoke( @@ -138,7 +138,7 @@ def test_tag_update(): def test_tag_create_without_required_fails(): """Test that zenml tag create fails.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] result = runner.invoke( create_command, @@ -150,7 +150,7 @@ def test_tag_delete_found(): """Test that zenml tag delete does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, @@ -165,7 +165,7 @@ def test_tag_delete_found(): def test_tag_delete_not_found(): """Test that zenml tag delete fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, diff --git a/tests/integration/functional/cli/utils.py b/tests/integration/functional/cli/utils.py index 7a8331a919c..b54664c67ad 100644 --- a/tests/integration/functional/cli/utils.py +++ b/tests/integration/functional/cli/utils.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import inspect from contextlib import contextmanager from typing import Generator, Optional, Tuple +from click.testing import CliRunner + from tests.harness.harness import TestHarness from zenml.cli import cli from zenml.cli.utils import ( @@ -35,6 +38,15 @@ ] +def cli_runner(**kwargs) -> CliRunner: + """Return a Click runner that stays compatible across Click releases.""" + if "mix_stderr" not in kwargs: + params = inspect.signature(CliRunner.__init__).parameters + if "mix_stderr" in params: + kwargs["mix_stderr"] = False + return CliRunner(**kwargs) + + # ----- # # USERS # # ----- #