diff --git a/docs/book/component-guide/step-operators/modal.md b/docs/book/component-guide/step-operators/modal.md index 83c0ec99035..2f18b6266fb 100644 --- a/docs/book/component-guide/step-operators/modal.md +++ b/docs/book/component-guide/step-operators/modal.md @@ -36,6 +36,13 @@ To use the Modal step operator, we need: cloud artifact store supported by ZenML will work with Modal. * A cloud container registry as part of your stack. Any cloud container registry supported by ZenML will work with Modal. +* An Image Builder in your stack. ZenML uses it to build the Docker image that + runs on Modal. + +The Modal step operator also respects the following environment variables if set: +- MODAL_TOKEN_ID, MODAL_TOKEN_SECRET: authentication tokens +- MODAL_WORKSPACE: workspace name +- MODAL_ENVIRONMENT: Modal environment name (e.g., "main") We can then register the step operator: @@ -66,30 +73,42 @@ You can specify the hardware requirements for each step using the `ResourceSettings` class as described in our documentation on [resource settings](https://docs.zenml.io/user-guides/tutorial/distributed-training): ```python +from zenml import step from zenml.config import ResourceSettings from zenml.integrations.modal.flavors import ModalStepOperatorSettings -modal_settings = ModalStepOperatorSettings(gpu="A100") +modal_settings = ModalStepOperatorSettings( + gpu="A100", # GPU type (e.g., "T4", "A100") + # region="us-east-1", # optional, enterprise/team only + # cloud="aws", # optional, enterprise/team only + # modal_environment="main", # optional + # timeout=86400, # optional, seconds +) + resource_settings = ResourceSettings( - cpu=2, - memory="32GB" + cpu_count=2, + memory="32GB", + # gpu_count=1, # optional; if omitted and a GPU type is set, defaults to 1 GPU ) @step( - step_operator=True, + step_operator=True, # or the specific name, e.g., step_operator="" settings={ "step_operator": modal_settings, - "resources": resource_settings - } + "resources": resource_settings, + }, ) def my_modal_step(): ... ``` +Important: +- If you request GPUs with `ResourceSettings.gpu_count > 0`, you must also specify a GPU type via `ModalStepOperatorSettings.gpu`; otherwise the run will fail with a validation error. +- If a GPU type is set but `gpu_count == 0`, ZenML defaults to 1 GPU and logs a warning. +- `cpu_count` must be an integer. `memory` can be a string like "32GB" or an integer amount of bytes. + {% hint style="info" %} Note that the `cpu` parameter in `ResourceSettings` currently only accepts a single integer value. This specifies a soft minimum limit - Modal will guarantee at least this many physical cores, but the actual usage could be higher. The CPU cores/hour will also determine the minimum price paid for the compute resources. - -For example, with the configuration above (2 CPUs and 32GB memory), the minimum cost would be approximately $1.03 per hour ((0.135 * 2) + (0.024 * 32) = $1.03). {% endhint %} This will run `my_modal_step` on a Modal instance with 1 A100 GPU, 2 CPUs, and @@ -108,8 +127,3 @@ pipeline execution failures. In the case of failures, however, Modal provides detailed error messages that can help identify what is incompatible. See more in the [Modal docs on region selection](https://modal.com/docs/guide/region-selection) for more details. - - -
ZenML Scarf
- - diff --git a/src/zenml/cli/cli.py b/src/zenml/cli/cli.py index baa29bafbe5..01346581782 100644 --- a/src/zenml/cli/cli.py +++ b/src/zenml/cli/cli.py @@ -43,7 +43,7 @@ def __init__( commands: Optional[ Union[Dict[str, click.Command], Sequence[click.Command]] ] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: """Initialize the Tag group. diff --git a/src/zenml/integrations/modal/__init__.py b/src/zenml/integrations/modal/__init__.py index 081628cb035..49e78c1e774 100644 --- a/src/zenml/integrations/modal/__init__.py +++ b/src/zenml/integrations/modal/__init__.py @@ -29,7 +29,7 @@ class ModalIntegration(Integration): """Definition of Modal integration for ZenML.""" NAME = MODAL - REQUIREMENTS = ["modal>=0.64.49,<1"] + REQUIREMENTS = ["modal>=1"] @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py index a66580faa54..13d55eea9d1 100644 --- a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py +++ b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py @@ -15,13 +15,18 @@ from typing import TYPE_CHECKING, Optional, Type +from pydantic import Field + from zenml.config.base_settings import BaseSettings from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR from zenml.step_operators import BaseStepOperatorConfig, BaseStepOperatorFlavor +from zenml.utils.secret_utils import SecretField if TYPE_CHECKING: from zenml.integrations.modal.step_operators import ModalStepOperator +DEFAULT_TIMEOUT_SECONDS = 86400 # 24 hours + class ModalStepOperatorSettings(BaseSettings): """Settings for the Modal step operator. @@ -36,20 +41,82 @@ class ModalStepOperatorSettings(BaseSettings): incompatible. See more in the Modal docs at https://modal.com/docs/guide/region-selection. Attributes: - gpu: The type of GPU to use for the step execution. + gpu: The type of GPU to use for the step execution (e.g., "T4", "A100"). + Use ResourceSettings.gpu_count to specify the number of GPUs. region: The region to use for the step execution. cloud: The cloud provider to use for the step execution. + modal_environment: The Modal environment to use for the step execution. + timeout: Maximum execution time in seconds (default 24h). """ - gpu: Optional[str] = None - region: Optional[str] = None - cloud: Optional[str] = None + gpu: Optional[str] = Field( + None, + description="GPU type for step execution. Must be a valid Modal GPU type. " + "Examples: 'T4' (cost-effective), 'A100' (high-performance), 'V100' (training workloads). " + "Use ResourceSettings.gpu_count to specify number of GPUs. If not specified, uses CPU-only execution", + ) + region: Optional[str] = Field( + None, + description="Cloud region for step execution. Must be a valid region for the selected cloud provider. " + "Examples: 'us-east-1', 'us-west-2', 'eu-west-1'. If not specified, Modal uses default region " + "based on cloud provider and availability", + ) + cloud: Optional[str] = Field( + None, + description="Cloud provider for step execution. Must be a valid Modal-supported cloud provider. " + "Examples: 'aws', 'gcp'. If not specified, Modal uses default cloud provider " + "based on workspace configuration", + ) + modal_environment: Optional[str] = Field( + None, + description="Modal environment name for step execution. Must be a valid environment " + "configured in your Modal workspace. Examples: 'main', 'staging', 'production'. " + "If not specified, uses the default environment for the workspace", + ) + timeout: int = Field( + DEFAULT_TIMEOUT_SECONDS, + ge=1, + le=DEFAULT_TIMEOUT_SECONDS, + description=f"Maximum execution time in seconds for step completion. Must be between 1 and {DEFAULT_TIMEOUT_SECONDS} seconds. " + f"Examples: 3600 (1 hour), 7200 (2 hours), {DEFAULT_TIMEOUT_SECONDS} (24 hours maximum). " + "Step execution will be terminated if it exceeds this timeout", + ) class ModalStepOperatorConfig( BaseStepOperatorConfig, ModalStepOperatorSettings ): - """Configuration for the Modal step operator.""" + """Configuration for the Modal step operator. + + Attributes: + token_id: Modal API token ID (ak-xxxxx format) for authentication. + token_secret: Modal API token secret (as-xxxxx format) for authentication. + workspace: Modal workspace name (optional). + + Note: If token_id and token_secret are not provided, falls back to + Modal's default authentication (~/.modal.toml). + All other configuration options (modal_environment, gpu, region, etc.) + are inherited from ModalStepOperatorSettings. + """ + + token_id: Optional[str] = SecretField( + default=None, + description="Modal API token ID for authentication. Must be in format 'ak-xxxxx' as provided by Modal. " + "Example: 'ak-1234567890abcdef'. If not provided, falls back to Modal's default authentication " + "from ~/.modal.toml file. Required for programmatic access to Modal API", + ) + token_secret: Optional[str] = SecretField( + default=None, + description="Modal API token secret for authentication. Must be in format 'as-xxxxx' as provided by Modal. " + "Example: 'as-abcdef1234567890'. Used together with token_id for API authentication. " + "If not provided, falls back to Modal's default authentication from ~/.modal.toml file", + ) + workspace: Optional[str] = Field( + None, + description="Modal workspace name for step execution. Must be a valid workspace name " + "you have access to. Examples: 'my-company', 'ml-team', 'personal-workspace'. " + "If not specified, uses the default workspace from Modal configuration", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 9c654ca6541..1454a3c3d84 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -14,21 +14,25 @@ """Modal step operator implementation.""" import asyncio -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast import modal -from modal_proto import api_pb2 from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration -from zenml.config.resource_settings import ByteUnit, ResourceSettings -from zenml.enums import StackComponentType +from zenml.config.resource_settings import ByteUnit from zenml.integrations.modal.flavors import ( ModalStepOperatorConfig, ModalStepOperatorSettings, ) +from zenml.integrations.modal.utils import ( + build_modal_image, + get_gpu_values, + get_modal_stack_validator, + setup_modal_client, +) from zenml.logger import get_logger -from zenml.stack import Stack, StackValidator +from zenml.stack import StackValidator from zenml.step_operators import BaseStepOperator if TYPE_CHECKING: @@ -41,24 +45,6 @@ MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY = "modal_step_operator" -def get_gpu_values( - settings: ModalStepOperatorSettings, resource_settings: ResourceSettings -) -> Optional[str]: - """Get the GPU values for the Modal step operator. - - Args: - settings: The Modal step operator settings. - resource_settings: The resource settings. - - Returns: - The GPU string if a count is specified, otherwise the GPU type. - """ - if not settings.gpu: - return None - gpu_count = resource_settings.gpu_count - return f"{settings.gpu}:{gpu_count}" if gpu_count else settings.gpu - - class ModalStepOperator(BaseStepOperator): """Step operator to run a step on Modal. @@ -91,40 +77,7 @@ def validator(self) -> Optional[StackValidator]: Returns: The stack validator. """ - - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: - if stack.artifact_store.config.is_local: - return False, ( - "The Modal step operator runs code remotely and " - "needs to write files into the artifact store, but the " - f"artifact store `{stack.artifact_store.name}` of the " - "active stack is local. Please ensure that your stack " - "contains a remote artifact store when using the Modal " - "step operator." - ) - - container_registry = stack.container_registry - assert container_registry is not None - - if container_registry.config.is_local: - return False, ( - "The Modal step operator runs code remotely and " - "needs to push/pull Docker images, but the " - f"container registry `{container_registry.name}` of the " - "active stack is local. Please ensure that your stack " - "contains a remote container registry when using the " - "Modal step operator." - ) - - return True, "" - - return StackValidator( - required_components={ - StackComponentType.CONTAINER_REGISTRY, - StackComponentType.IMAGE_BUILDER, - }, - custom_validation_function=_validate_remote_components, - ) + return get_modal_stack_validator() def get_docker_builds( self, snapshot: "PipelineSnapshotBase" @@ -163,80 +116,85 @@ def launch( environment: The environment variables for the step. Raises: - RuntimeError: If no Docker credentials are found for the container registry. - ValueError: If no container registry is found in the stack. + RuntimeError: If Modal image construction fails, sandbox creation fails, + sandbox execution fails, or Modal application initialization fails. """ settings = cast(ModalStepOperatorSettings, self.get_settings(info)) image_name = info.get_image(key=MODAL_STEP_OPERATOR_DOCKER_IMAGE_KEY) zc = Client() stack = zc.active_stack - if not stack.container_registry: - raise ValueError( - "No Container registry found in the stack. " - "Please add a container registry and ensure " - "it is correctly configured." - ) - - if docker_creds := stack.container_registry.credentials: - docker_username, docker_password = docker_creds - else: - raise RuntimeError( - "No Docker credentials found for the container registry." - ) - - my_secret = modal.secret._Secret.from_dict( - { - "REGISTRY_USERNAME": docker_username, - "REGISTRY_PASSWORD": docker_password, - } - ) - - spec = modal.image.DockerfileSpec( - commands=[f"FROM {image_name}"], context_files={} + setup_modal_client( + token_id=self.config.token_id, + token_secret=self.config.token_secret, + workspace=self.config.workspace, + environment=settings.modal_environment + or self.config.modal_environment, ) - zenml_image = modal.Image._from_args( - dockerfile_function=lambda *_, **__: spec, - force_build=False, - image_registry_config=modal.image._ImageRegistryConfig( - api_pb2.REGISTRY_AUTH_TYPE_STATIC_CREDS, my_secret - ), - ).env(environment) + try: + modal_image = build_modal_image(image_name, stack, environment) + except Exception as e: + raise RuntimeError( + "Failed to construct Modal execution environment from your Docker image. " + "Action required: verify that Modal can access your container registry (check network connectivity " + "and registry permissions), and that the Docker image can be pulled and extended with additional dependencies. " + f"Context: image='{image_name}'." + ) from e resource_settings = info.config.resource_settings + gpu_values = get_gpu_values(settings, resource_settings) + memory_int = ( + int(mb) + if (mb := resource_settings.get_memory(ByteUnit.MB)) + else None + ) + app = modal.App( f"zenml-{info.run_name}-{info.step_run_id}-{info.pipeline_step_name}" ) - async def run_sandbox() -> asyncio.Future[None]: - loop = asyncio.get_event_loop() - future = loop.create_future() + async def run_sandbox() -> None: with modal.enable_output(): - async with app.run(): - memory_mb = resource_settings.get_memory(ByteUnit.MB) - memory_int = ( - int(memory_mb) if memory_mb is not None else None - ) - sb = await modal.Sandbox.create.aio( - "bash", - "-c", - " ".join(entrypoint_command), - image=zenml_image, - gpu=gpu_values, - cpu=resource_settings.cpu_count, - memory=memory_int, - cloud=settings.cloud, - region=settings.region, - app=app, - timeout=86400, # 24h, the max Modal allows - ) - - await sb.wait.aio() - - future.set_result(None) - return future + try: + async with app.run(): + try: + sb = await modal.Sandbox.create.aio( + *entrypoint_command, + image=modal_image, + gpu=gpu_values, + cpu=resource_settings.cpu_count, + memory=memory_int, + cloud=settings.cloud, + region=settings.region, + app=app, + timeout=settings.timeout, + ) + except Exception as e: + raise RuntimeError( + "Failed to create a Modal sandbox. " + "Action required: verify that the referenced Docker image is accessible, " + "the requested resources are available (gpu/region/cloud), and your Modal workspace " + "permissions allow sandbox creation. " + f"Context: image='{image_name}', gpu='{gpu_values}', region='{settings.region}', cloud='{settings.cloud}'." + ) from e + + try: + await sb.wait.aio() + except Exception as e: + raise RuntimeError( + "Modal sandbox execution failed. " + "Action required: inspect the step logs in Modal, and confirm that " + "dependencies are available in the image/environment." + ) from e + except Exception as e: + raise RuntimeError( + "Failed to initialize Modal application context (authentication / workspace / environment). " + "Action required: make sure you're authenticated with Modal (run 'modal token new' or set " + "MODAL_TOKEN_ID and MODAL_TOKEN_SECRET), and that the configured workspace/environment exist " + "and you have access to them." + ) from e asyncio.run(run_sandbox()) diff --git a/src/zenml/integrations/modal/utils.py b/src/zenml/integrations/modal/utils.py new file mode 100644 index 00000000000..fbbcf44240a --- /dev/null +++ b/src/zenml/integrations/modal/utils.py @@ -0,0 +1,331 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Shared utilities for Modal integration components.""" + +import os +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import modal + +from zenml.config.resource_settings import ResourceSettings +from zenml.enums import StackComponentType +from zenml.exceptions import StackComponentInterfaceError +from zenml.logger import get_logger +from zenml.stack import Stack, StackValidator + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from zenml.integrations.modal.flavors import ModalStepOperatorSettings + +MODAL_TOKEN_ID_ENV = "MODAL_TOKEN_ID" +MODAL_TOKEN_SECRET_ENV = "MODAL_TOKEN_SECRET" +MODAL_WORKSPACE_ENV = "MODAL_WORKSPACE" +MODAL_ENVIRONMENT_ENV = "MODAL_ENVIRONMENT" +MODAL_CONFIG_PATH = os.path.expanduser("~/.modal.toml") + + +def _validate_token_prefix( + value: str, expected_prefix: str, label: str +) -> None: + """Warn if a credential doesn't match Modal's expected prefix. + + This helps catch misconfigurations early without logging secret content. + + Args: + value: The credential value to validate. + expected_prefix: The prefix that the credential should start with. + label: Human-readable label for the credential (used in warning messages). + """ + if not value.startswith(expected_prefix): + logger.warning( + f"{label} format may be invalid. Expected prefix: {expected_prefix}" + ) + + +def _set_env_if_present(var_name: str, value: Optional[str]) -> bool: + """Set an environment variable only if a non-empty value is provided. + + Args: + var_name: The name of the environment variable to set. + value: The value to set for the environment variable, or None. + + Returns: + True if the environment variable was set, False otherwise. + """ + if value is None or value == "": + return False + os.environ[var_name] = value + return True + + +def setup_modal_client( + token_id: Optional[str] = None, + token_secret: Optional[str] = None, + workspace: Optional[str] = None, + environment: Optional[str] = None, +) -> None: + """Setup Modal client authentication and context. + + Precedence for credentials: + 1) Explicit arguments (token_id, token_secret) + 2) Existing environment variables (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) + 3) Default Modal configuration (~/.modal.toml) + + Notes: + - The 'environment' parameter refers to the Modal environment name (e.g., 'main'), + not a dict of process environment variables. + - This function avoids logging secret values. Validation only checks known prefixes. + + Args: + token_id: Modal API token ID (ak-xxxxx format). + token_secret: Modal API token secret (as-xxxxx format). + workspace: Modal workspace name. + environment: Modal environment name. + """ + # Remember whether values came from args vs. env to improve diagnostics without leaking secrets. + arg_token_id = token_id + arg_token_secret = token_secret + + # Coalesce from env if not explicitly provided. This reduces friction if users + # supply one value via configuration and the other via environment. + token_id = token_id or os.environ.get(MODAL_TOKEN_ID_ENV) + token_secret = token_secret or os.environ.get(MODAL_TOKEN_SECRET_ENV) + + tokens_provided = [] + token_sources: Dict[str, str] = {} + + if token_id: + _validate_token_prefix(token_id, "ak-", "Token ID") + _set_env_if_present(MODAL_TOKEN_ID_ENV, token_id) + tokens_provided.append("ID") + token_sources["ID"] = "args" if arg_token_id else "env" + + if token_secret: + _validate_token_prefix(token_secret, "as-", "Token secret") + _set_env_if_present(MODAL_TOKEN_SECRET_ENV, token_secret) + tokens_provided.append("secret") + token_sources["secret"] = "args" if arg_token_secret else "env" + + if tokens_provided: + if len(tokens_provided) == 1: + logger.warning( + f"Only token {tokens_provided[0]} provided. Ensure both are set." + ) + source_parts: List[str] = [] + if "ID" in token_sources: + source_parts.append(f"ID from {token_sources['ID']}") + if "secret" in token_sources: + source_parts.append(f"secret from {token_sources['secret']}") + source_summary = ( + " and ".join(source_parts) if source_parts else "args/env" + ) + logger.debug(f"Using Modal API tokens ({source_summary}).") + else: + # Fall back to default Modal CLI auth configuration. + logger.debug("Using default platform authentication (~/.modal.toml)") + if os.path.exists(MODAL_CONFIG_PATH): + logger.debug(f"Found platform config at {MODAL_CONFIG_PATH}") + else: + logger.warning( + f"No platform config found at {MODAL_CONFIG_PATH}. " + "Run 'modal token new' to set up authentication." + ) + + if workspace: + _set_env_if_present(MODAL_WORKSPACE_ENV, workspace) + if environment: + _set_env_if_present(MODAL_ENVIRONMENT_ENV, environment) + + +def build_modal_image( + image_name: str, + stack: "Stack", + environment: Optional[Dict[str, str]] = None, +) -> modal.Image: + """Build a Modal image from a Docker registry with authentication. + + This helper function centralizes the shared logic for building Modal images + from Docker registries, including credential validation, secret creation, + and image building with Modal installation. + + Args: + image_name: The name of the Docker image to use as base. + stack: The ZenML stack containing container registry. + environment: Optional environment variables to apply to the image. + + Returns: + The configured Modal image. + + Raises: + RuntimeError: If no Docker credentials are found. + """ + if not stack.container_registry: + raise RuntimeError( + "No Container registry found in the stack. " + "Please add a container registry and ensure " + "it is correctly configured." + ) + + if docker_creds := stack.container_registry.credentials: + docker_username, docker_password = docker_creds + else: + raise RuntimeError( + "No Docker credentials found for the container registry." + ) + + registry_secret = modal.Secret.from_dict( + { + "REGISTRY_USERNAME": docker_username, + "REGISTRY_PASSWORD": docker_password, + } + ) + + try: + modal_image = modal.Image.from_registry( + image_name, secret=registry_secret + ) + except Exception as e: + raise RuntimeError( + "Failed to construct a Modal image from the specified Docker base image. " + "Action required: ensure the image exists and is accessible from your container registry, " + "and that the provided credentials are correct." + ) from e + + if environment: + try: + modal_image = modal_image.env(environment) + except Exception as e: + # This is a defensive guard; env composition is local, but we still provide guidance. + raise RuntimeError( + "Failed to apply environment variables to the Modal image. " + "Action required: verify that environment variable keys and values are valid strings." + ) from e + + return modal_image + + +def get_gpu_values( + settings: "ModalStepOperatorSettings", + resource_settings: "ResourceSettings", +) -> Optional[str]: + """Compute and validate the Modal ``gpu`` argument string. + + Modal expects GPU resources as either ``None`` (CPU only), a GPU type string + like ``"A100"`` (implicitly a single GPU), or ``"A100:2"`` when multiple + GPUs of the same type are requested. Within ZenML, the GPU type is captured + in :class:`ModalStepOperatorSettings` while the count lives in + :class:`~zenml.config.resource_settings.ResourceSettings`. This helper + reconciles both sources so other Modal components can reuse the same + validation rules. + + Args: + settings: The Modal step operator settings describing the GPU type. + resource_settings: Resource constraints for the step, providing the GPU count. + + Returns: + A Modal-compatible GPU specification string or ``None`` when running on CPU. + + Raises: + StackComponentInterfaceError: If the configuration is inconsistent or invalid. + """ + gpu_type_raw = settings.gpu + gpu_type = gpu_type_raw.strip() if gpu_type_raw is not None else None + if gpu_type == "": + gpu_type = None + + gpu_count = resource_settings.gpu_count + if gpu_count is not None: + try: + gpu_count = int(gpu_count) + except (TypeError, ValueError): + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be a non-negative integer." + ) + if gpu_count < 0: + raise StackComponentInterfaceError( + f"Invalid GPU count '{gpu_count}'. Must be >= 0." + ) + + if gpu_type is None: + if gpu_count is not None and gpu_count > 0: + raise StackComponentInterfaceError( + "GPU resources requested (gpu_count > 0) but no GPU type was specified " + "in Modal settings. Please set a GPU type (e.g., 'T4', 'A100') via " + "ModalStepOperatorSettings.gpu or @step(settings={'modal': {'gpu': ''}}), " + "or set gpu_count=0 to run on CPU." + ) + return None + + if gpu_count == 0: + logger.warning( + "Modal GPU type '%s' is configured but ResourceSettings.gpu_count is 0. " + "Defaulting to 1 GPU. To run on CPU only, remove the GPU type or ensure " + "gpu_count=0 with no GPU type configured.", + gpu_type, + ) + return gpu_type + + if gpu_count is None: + return gpu_type + + if gpu_count > 0: + return f"{gpu_type}:{gpu_count}" + + return None + + +def get_modal_stack_validator() -> StackValidator: + """Get a stack validator for Modal components. + + The validator ensures that the stack contains a remote artifact store and + container registry. + + Returns: + A stack validator for modal components. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.config.is_local: + return False, ( + "Serverless components run code remotely and " + "need to write files into the artifact store, but the " + f"artifact store `{stack.artifact_store.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote artifact store when using serverless " + "components." + ) + + container_registry = stack.container_registry + assert container_registry is not None + + if container_registry.config.is_local: + return False, ( + "Serverless components run code remotely and " + "need to push/pull Docker images, but the " + f"container registry `{container_registry.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote container registry when using serverless " + "components." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.CONTAINER_REGISTRY, + StackComponentType.IMAGE_BUILDER, + }, + custom_validation_function=_validate_remote_components, + ) diff --git a/tests/integration/functional/cli/test_model.py b/tests/integration/functional/cli/test_model.py index 48c67b298ad..27bd55cfb92 100644 --- a/tests/integration/functional/cli/test_model.py +++ b/tests/integration/functional/cli/test_model.py @@ -16,16 +16,16 @@ from uuid import uuid4 import pytest -from click.testing import CliRunner from tests.integration.functional.cli.conftest import NAME, PREFIX +from tests.integration.functional.cli.utils import cli_runner from zenml.cli.cli import cli from zenml.client import Client def test_model_list(clean_client_with_models: "Client"): """Test that zenml model list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -33,7 +33,7 @@ def test_model_list(clean_client_with_models: "Client"): def test_model_create_short_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with short names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -82,7 +82,7 @@ def test_model_create_short_names(clean_client_with_models: "Client"): def test_model_create_full_names(clean_client_with_models: "Client"): """Test that zenml model create does not fail with full names.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -131,7 +131,7 @@ def test_model_create_full_names(clean_client_with_models: "Client"): def test_model_create_only_required(clean_client_with_models: "Client"): """Test that zenml model create does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] model_name = PREFIX + str(uuid4()) result = runner.invoke( @@ -155,7 +155,7 @@ def test_model_create_only_required(clean_client_with_models: "Client"): def test_model_update(clean_client_with_models: "Client"): """Test that zenml model update does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["model"].commands["update"] result = runner.invoke( update_command, @@ -185,7 +185,7 @@ def test_model_create_without_required_fails( clean_client_with_models: "Client", ): """Test that zenml model create fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["model"].commands["register"] result = runner.invoke( create_command, @@ -195,7 +195,7 @@ def test_model_create_without_required_fails( def test_model_delete_found(clean_client_with_models: "Client"): """Test that zenml model delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) create_command = cli.commands["model"].commands["register"] runner.invoke( @@ -212,7 +212,7 @@ def test_model_delete_found(clean_client_with_models: "Client"): def test_model_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() name = PREFIX + str(uuid4()) delete_command = cli.commands["model"].commands["delete"] result = runner.invoke( @@ -224,7 +224,7 @@ def test_model_delete_not_found(clean_client_with_models: "Client"): def test_model_version_list(clean_client_with_models: "Client"): """Test that zenml model version list does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands["version"].commands["list"] result = runner.invoke(list_command, args=[f"--model={NAME}"]) assert result.exit_code == 0, result.stderr @@ -232,7 +232,7 @@ def test_model_version_list(clean_client_with_models: "Client"): def test_model_version_delete_found(clean_client_with_models: "Client"): """Test that zenml model version delete does not fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) model = clean_client_with_models.create_model( @@ -254,7 +254,7 @@ def test_model_version_delete_found(clean_client_with_models: "Client"): def test_model_version_delete_not_found(clean_client_with_models: "Client"): """Test that zenml model version delete fail.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() model_name = PREFIX + str(uuid4()) model_version_name = PREFIX + str(uuid4()) clean_client_with_models.create_model( @@ -278,7 +278,7 @@ def test_model_version_links_list( command: str, clean_client_with_models: "Client" ): """Test that zenml model version artifacts list fails.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["model"].commands[command] result = runner.invoke( list_command, @@ -289,7 +289,7 @@ def test_model_version_links_list( def test_model_version_update(clean_client_with_models: "Client"): """Test that zenml model version stage update pass.""" - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = ( cli.commands["model"].commands["version"].commands["update"] ) diff --git a/tests/integration/functional/cli/test_tag.py b/tests/integration/functional/cli/test_tag.py index 775e6f1b013..0d20767f962 100644 --- a/tests/integration/functional/cli/test_tag.py +++ b/tests/integration/functional/cli/test_tag.py @@ -14,9 +14,9 @@ """Test zenml tag CLI commands.""" import pytest -from click.testing import CliRunner from tests.integration.functional.cli.utils import ( + cli_runner, random_resource_name, tags_killer, ) @@ -29,7 +29,7 @@ def test_tag_list(): """Test that zenml tag list does not fail.""" with tags_killer(): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() list_command = cli.commands["tag"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0, result.stderr @@ -38,7 +38,7 @@ def test_tag_list(): def test_tag_create_short_names(): """Test that zenml tag create does not fail with short names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -55,7 +55,7 @@ def test_tag_create_short_names(): def test_tag_create_full_names(): """Test that zenml tag create does not fail with full names.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -72,7 +72,7 @@ def test_tag_create_full_names(): def test_tag_create_only_required(): """Test that zenml tag create does not fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] tag_name = random_resource_name() result = runner.invoke( @@ -93,7 +93,7 @@ def test_tag_update(): """Test that zenml tag update does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() update_command = cli.commands["tag"].commands["update"] color_to_set = "yellow" if tag.color.value != "yellow" else "grey" result = runner.invoke( @@ -138,7 +138,7 @@ def test_tag_update(): def test_tag_create_without_required_fails(): """Test that zenml tag create fails.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() create_command = cli.commands["tag"].commands["register"] result = runner.invoke( create_command, @@ -150,7 +150,7 @@ def test_tag_delete_found(): """Test that zenml tag delete does not fail.""" with tags_killer(1) as tags: tag: TagResponse = tags[0] - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, @@ -165,7 +165,7 @@ def test_tag_delete_found(): def test_tag_delete_not_found(): """Test that zenml tag delete fail.""" with tags_killer(0): - runner = CliRunner(mix_stderr=False) + runner = cli_runner() delete_command = cli.commands["tag"].commands["delete"] result = runner.invoke( delete_command, diff --git a/tests/integration/functional/cli/utils.py b/tests/integration/functional/cli/utils.py index 7a8331a919c..b54664c67ad 100644 --- a/tests/integration/functional/cli/utils.py +++ b/tests/integration/functional/cli/utils.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +import inspect from contextlib import contextmanager from typing import Generator, Optional, Tuple +from click.testing import CliRunner + from tests.harness.harness import TestHarness from zenml.cli import cli from zenml.cli.utils import ( @@ -35,6 +38,15 @@ ] +def cli_runner(**kwargs) -> CliRunner: + """Return a Click runner that stays compatible across Click releases.""" + if "mix_stderr" not in kwargs: + params = inspect.signature(CliRunner.__init__).parameters + if "mix_stderr" in params: + kwargs["mix_stderr"] = False + return CliRunner(**kwargs) + + # ----- # # USERS # # ----- # diff --git a/tests/integration/integrations/modal/flavors/__init__.py b/tests/integration/integrations/modal/flavors/__init__.py new file mode 100644 index 00000000000..32ba63d3703 --- /dev/null +++ b/tests/integration/integrations/modal/flavors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/tests/integration/integrations/modal/flavors/test_flavor_config.py b/tests/integration/integrations/modal/flavors/test_flavor_config.py new file mode 100644 index 00000000000..ab3dd2d57ef --- /dev/null +++ b/tests/integration/integrations/modal/flavors/test_flavor_config.py @@ -0,0 +1,56 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import pytest +from pydantic import ValidationError + +from zenml.integrations.modal.flavors.modal_step_operator_flavor import ( + DEFAULT_TIMEOUT_SECONDS, + ModalStepOperatorSettings, +) + + +def test_modal_settings_timeout_accepts_valid_values() -> None: + # Test minimum valid value + settings = ModalStepOperatorSettings(timeout=1) + assert settings.timeout == 1 + + # Test maximum valid value + settings = ModalStepOperatorSettings(timeout=DEFAULT_TIMEOUT_SECONDS) + assert settings.timeout == DEFAULT_TIMEOUT_SECONDS + + # Test mid-range value + settings = ModalStepOperatorSettings(timeout=3600) + assert settings.timeout == 3600 + + +def test_modal_settings_timeout_rejects_below_minimum() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=0) + + assert "greater than or equal to 1" in str(exc_info.value) + + +def test_modal_settings_timeout_rejects_above_maximum() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=DEFAULT_TIMEOUT_SECONDS + 1) + + assert "less than or equal to" in str(exc_info.value) + + +def test_modal_settings_timeout_rejects_negative() -> None: + with pytest.raises(ValidationError) as exc_info: + ModalStepOperatorSettings(timeout=-1) + + assert "greater than or equal to 1" in str(exc_info.value) diff --git a/tests/integration/integrations/modal/step_operators/__init__.py b/tests/integration/integrations/modal/step_operators/__init__.py index cd90a82cfc2..32ba63d3703 100644 --- a/tests/integration/integrations/modal/step_operators/__init__.py +++ b/tests/integration/integrations/modal/step_operators/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py index 865f6d55265..91142b8b3ab 100644 --- a/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py +++ b/tests/integration/integrations/modal/step_operators/test_modal_step_operator.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +12,101 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. - import pytest -from zenml.config.resource_settings import ResourceSettings +# Skip this entire module if the optional dependency isn't available, because +# the module under test imports `modal` at import time. +pytest.importorskip("modal") + +from zenml.exceptions import StackComponentInterfaceError from zenml.integrations.modal.flavors import ModalStepOperatorSettings -from zenml.integrations.modal.step_operators.modal_step_operator import ( - get_gpu_values, -) - - -@pytest.mark.parametrize( - "gpu, gpu_count, expected_result", - [ - ("", None, None), - (None, None, None), - ("", 1, None), - (None, 1, None), - ("A100", None, "A100"), - ("A100", 0, "A100"), - ("A100", 1, "A100:1"), - ("A100", 2, "A100:2"), - ("V100", None, "V100"), - ("V100", 0, "V100"), - ("V100", 1, "V100:1"), - ("V100", 2, "V100:2"), - ], -) -def test_get_gpu_values(gpu, gpu_count, expected_result): - settings = ModalStepOperatorSettings(gpu=gpu) - resource_settings = ResourceSettings(gpu_count=gpu_count) - result = get_gpu_values(settings, resource_settings) - assert result == expected_result +from zenml.integrations.modal.utils import get_gpu_values + + +class ResourceSettingsStub: + """Minimal stub to simulate ZenML ResourceSettings for GPU tests. + + We only model the `gpu_count` attribute because that's the only part the + helper uses. This keeps tests lightweight and avoids wider dependencies. + """ + + def __init__(self, gpu_count): + self.gpu_count = gpu_count + + +def test_gpu_arg_none_when_no_type_and_no_count() -> None: + settings = ModalStepOperatorSettings(gpu=None) + rs = ResourceSettingsStub(gpu_count=None) + assert get_gpu_values(settings, rs) is None + + +def test_gpu_arg_raises_when_count_without_type() -> None: + settings = ModalStepOperatorSettings(gpu=None) + rs = ResourceSettingsStub(gpu_count=1) + with pytest.raises(StackComponentInterfaceError) as e: + get_gpu_values(settings, rs) + assert ( + "GPU resources requested (gpu_count > 0) but no GPU type was specified" + in str(e.value) + ) + + +def test_gpu_arg_type_with_no_count_returns_type() -> None: + settings = ModalStepOperatorSettings(gpu="A100") + rs = ResourceSettingsStub(gpu_count=None) + assert get_gpu_values(settings, rs) == "A100" + + +def test_gpu_arg_type_with_count_returns_type_colon_count() -> None: + settings = ModalStepOperatorSettings(gpu="A100") + + rs_two = ResourceSettingsStub(gpu_count=2) + assert get_gpu_values(settings, rs_two) == "A100:2" + + rs_one = ResourceSettingsStub(gpu_count=1) + assert get_gpu_values(settings, rs_one) == "A100:1" + + +def test_gpu_arg_type_with_zero_count_warns_and_defaults_to_single_gpu( + caplog, +) -> None: + settings = ModalStepOperatorSettings(gpu="A100") + rs = ResourceSettingsStub(gpu_count=0) + + with caplog.at_level("WARNING"): + result = get_gpu_values(settings, rs) + + assert result == "A100" + assert "Defaulting to 1 GPU" in caplog.text + + +def test_gpu_arg_invalid_negative_count_raises() -> None: + settings = ModalStepOperatorSettings(gpu="T4") + rs = ResourceSettingsStub(gpu_count=-1) + with pytest.raises(StackComponentInterfaceError) as e: + get_gpu_values(settings, rs) + assert "Invalid GPU count" in str(e.value) + + +def test_gpu_arg_non_integer_count_raises() -> None: + settings = ModalStepOperatorSettings(gpu="T4") + rs = ResourceSettingsStub(gpu_count="two") + with pytest.raises(StackComponentInterfaceError) as e: + get_gpu_values(settings, rs) + assert "Invalid GPU count" in str(e.value) + + +def test_gpu_arg_whitespace_type_treated_as_none_behavior() -> None: + settings = ModalStepOperatorSettings(gpu=" ") + + # With positive GPU count this should raise since type is treated as None. + rs_positive = ResourceSettingsStub(gpu_count=2) + with pytest.raises(StackComponentInterfaceError): + get_gpu_values(settings, rs_positive) + + # With zero or None count, this should be CPU-only (None). + rs_zero = ResourceSettingsStub(gpu_count=0) + assert get_gpu_values(settings, rs_zero) is None + + rs_none = ResourceSettingsStub(gpu_count=None) + assert get_gpu_values(settings, rs_none) is None diff --git a/tests/integration/integrations/modal/test_modal_utils.py b/tests/integration/integrations/modal/test_modal_utils.py new file mode 100644 index 00000000000..5cf30b77f75 --- /dev/null +++ b/tests/integration/integrations/modal/test_modal_utils.py @@ -0,0 +1,172 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import os + +import pytest + +# Skip this entire module if the optional dependency isn't available, because +# the module under test imports `modal` at import time. +pytest.importorskip("modal") + +from zenml.integrations.modal.utils import ( + MODAL_CONFIG_PATH, + MODAL_ENVIRONMENT_ENV, + MODAL_TOKEN_ID_ENV, + MODAL_TOKEN_SECRET_ENV, + MODAL_WORKSPACE_ENV, + _set_env_if_present, + _validate_token_prefix, + build_modal_image, + setup_modal_client, +) + + +class StackStubNoRegistry: + """Stack stub with no container registry to trigger early validation.""" + + container_registry = None + + +class ContainerRegistryStubNoCreds: + """Container registry stub with missing credentials.""" + + def __init__(self): + self.credentials = None + + +class StackStubWithRegistryNoCreds: + """Stack stub with a container registry but no credentials.""" + + def __init__(self): + self.container_registry = ContainerRegistryStubNoCreds() + + +def test_build_modal_image_raises_when_no_registry() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubNoRegistry(), environment=None + ) + assert "No Container registry found in the stack" in str(e.value) + + +def test_build_modal_image_raises_when_no_credentials() -> None: + with pytest.raises(RuntimeError) as e: + build_modal_image( + "repo/image:tag", StackStubWithRegistryNoCreds(), environment=None + ) + assert "No Docker credentials found for the container registry" in str( + e.value + ) + + +def test_set_env_if_present_sets_and_returns_true(monkeypatch) -> None: + monkeypatch.delenv("FOO_TEST", raising=False) + assert _set_env_if_present("FOO_TEST", "bar") is True + assert os.environ["FOO_TEST"] == "bar" + + +def test_set_env_if_present_ignores_none_and_empty(monkeypatch) -> None: + monkeypatch.delenv("FOO_TEST", raising=False) + + assert _set_env_if_present("FOO_TEST", None) is False + assert "FOO_TEST" not in os.environ + + assert _set_env_if_present("FOO_TEST", "") is False + assert "FOO_TEST" not in os.environ + + +def test_validate_token_prefix_warns_on_invalid_prefix(caplog) -> None: + with caplog.at_level("WARNING"): + _validate_token_prefix("invalid-token", "ak-", "Token ID") + assert "Expected prefix: ak-" in caplog.text + + +def test_setup_modal_client_sets_env_from_args(monkeypatch) -> None: + # Ensure a clean environment + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + monkeypatch.delenv(MODAL_WORKSPACE_ENV, raising=False) + monkeypatch.delenv(MODAL_ENVIRONMENT_ENV, raising=False) + + setup_modal_client( + token_id="ak-abc123", + token_secret="as-def456", + workspace="my-ws", + environment="main", + ) + + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-abc123" + assert os.environ[MODAL_TOKEN_SECRET_ENV] == "as-def456" + assert os.environ[MODAL_WORKSPACE_ENV] == "my-ws" + assert os.environ[MODAL_ENVIRONMENT_ENV] == "main" + + +def test_setup_modal_client_warns_when_only_one_token_provided( + monkeypatch, caplog +) -> None: + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + with caplog.at_level("WARNING"): + setup_modal_client(token_id="ak-abc123", token_secret=None) + + assert "Only token ID provided" in caplog.text + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-abc123" + assert MODAL_TOKEN_SECRET_ENV not in os.environ + + +def test_setup_modal_client_prefers_args_over_env(monkeypatch) -> None: + monkeypatch.setenv(MODAL_TOKEN_ID_ENV, "ak-old") + monkeypatch.setenv(MODAL_TOKEN_SECRET_ENV, "as-old") + + setup_modal_client(token_id="ak-new", token_secret="as-new") + + assert os.environ[MODAL_TOKEN_ID_ENV] == "ak-new" + assert os.environ[MODAL_TOKEN_SECRET_ENV] == "as-new" + + +def test_setup_modal_client_logs_missing_config_when_no_tokens_and_no_config( + monkeypatch, caplog +) -> None: + # Clear tokens from environment + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + # Pretend modal config is missing + monkeypatch.setenv( + "HOME", "/nonexistent-home-for-test" + ) # defensive isolation + monkeypatch.setattr("os.path.exists", lambda path: False) + + with caplog.at_level("WARNING"): + setup_modal_client() + + # Should warn that the default config file is missing + assert "No platform config found at" in caplog.text + assert str(MODAL_CONFIG_PATH) in caplog.text + + +def test_setup_modal_client_warns_on_bad_token_prefixes( + monkeypatch, caplog +) -> None: + monkeypatch.delenv(MODAL_TOKEN_ID_ENV, raising=False) + monkeypatch.delenv(MODAL_TOKEN_SECRET_ENV, raising=False) + + with caplog.at_level("WARNING"): + setup_modal_client(token_id="bad-id", token_secret="bad-secret") + + # Both ID and secret should trigger prefix warnings + assert "Expected prefix: ak-" in caplog.text + assert "Expected prefix: as-" in caplog.text