diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py
index 28cd8a82a85..c30de95da5f 100644
--- a/src/zenml/integrations/aws/__init__.py
+++ b/src/zenml/integrations/aws/__init__.py
@@ -28,6 +28,7 @@
 AWS_CONTAINER_REGISTRY_FLAVOR = "aws"
 AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = "sagemaker"
 AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = "sagemaker"
+AWS_BATCH_STEP_OPERATOR_FLAVOR = "aws_batch"
 AWS_DEPLOYER_FLAVOR = "aws"
 
 # Service connector constants
@@ -66,6 +67,7 @@ def flavors(cls) -> List[Type[Flavor]]:
             AWSImageBuilderFlavor,
             SagemakerOrchestratorFlavor,
             SagemakerStepOperatorFlavor,
+            AWSBatchStepOperatorFlavor,
         )
 
         return [
@@ -74,4 +76,5 @@ def flavors(cls) -> List[Type[Flavor]]:
             AWSImageBuilderFlavor,
             SagemakerStepOperatorFlavor,
             SagemakerOrchestratorFlavor,
+            AWSBatchStepOperatorFlavor,
         ]
diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py
index 823c08bbcdf..94a9b92740d 100644
--- a/src/zenml/integrations/aws/flavors/__init__.py
+++ b/src/zenml/integrations/aws/flavors/__init__.py
@@ -33,6 +33,10 @@
     SagemakerStepOperatorConfig,
     SagemakerStepOperatorFlavor,
 )
+from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import (
+    AWSBatchStepOperatorConfig,
+    AWSBatchStepOperatorFlavor,
+)
 
 __all__ = [
     "AWSContainerRegistryFlavor",
@@ -45,4 +49,6 @@
     "SagemakerStepOperatorConfig",
     "SagemakerOrchestratorFlavor",
     "SagemakerOrchestratorConfig",
+    "AWSBatchStepOperatorFlavor",
+    "AWSBatchStepOperatorConfig",
 ]
diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py
new file mode 100644
index 00000000000..a042dc53b95
--- /dev/null
+++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py
@@ -0,0 +1,201 @@
+# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""AWS Batch step operator flavor."""
+
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Type
+
+from pydantic import Field, PositiveInt
+
+from zenml.config.base_settings import BaseSettings
+from zenml.integrations.aws import (
+    AWS_BATCH_STEP_OPERATOR_FLAVOR,
+    AWS_RESOURCE_TYPE,
+)
+from zenml.models import ServiceConnectorRequirements
+from zenml.step_operators.base_step_operator import (
+    BaseStepOperatorConfig,
+    BaseStepOperatorFlavor,
+)
+from zenml.utils.secret_utils import SecretField
+
+if TYPE_CHECKING:
+    from zenml.integrations.aws.step_operators import AWSBatchStepOperator
+
+
+class AWSBatchStepOperatorSettings(BaseSettings):
+    """Settings for the AWS Batch step operator."""
+
+    environment: Dict[str, str] = Field(
+        default_factory=dict,
+        description="Environment variables to pass to the container during "
+        "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
+    )
+    job_queue_name: str = Field(
+        default="",
+        description="The AWS Batch job queue to submit the step's AWS Batch "
+        "job to. If not provided, falls back to the default job queue name "
+        "specified at stack registration time. Must be compatible with "
+        "`backend`.",
+    )
+    backend: Literal["EC2", "FARGATE"] = Field(
+        default="FARGATE",
+        description="The AWS Batch platform capability to orchestrate the "
+        "step's AWS Batch job with. Must be compatible with "
+        "`job_queue_name`. Defaults to 'FARGATE'.",
+    )
+    assign_public_ip: Literal["ENABLED", "DISABLED"] = Field(
+        default="ENABLED",
+        description="Sets the network configuration's assignPublicIp field. "
+        "Only relevant for the FARGATE backend.",
+    )
+    timeout_seconds: PositiveInt = Field(
+        default=3600,
+        description="The number of seconds before AWS Batch times out the "
+        "job.",
+    )
+
+
+class AWSBatchStepOperatorConfig(
+    BaseStepOperatorConfig, AWSBatchStepOperatorSettings
+):
+    """Config for the AWS Batch step operator.
+
+    Note: We use ECS as a backend (not EKS), supporting both the EC2 and
+    FARGATE platform capabilities. This is because
+    - users can avoid the complexity of setting up an EKS cluster, and
+    - we can add AWS Batch multi-node job support later, which requires EC2.
+    """
+
+    execution_role: str = Field(
+        description="The IAM role ARN of the ECS execution role."
+    )
+    job_role: str = Field(
+        description="The IAM role ARN of the ECS job role."
+    )
+    default_job_queue_name: str = Field(
+        description="The default AWS Batch job queue to submit AWS Batch "
+        "jobs to."
+    )
+    aws_access_key_id: Optional[str] = SecretField(
+        default=None,
+        description="The AWS access key ID to use to authenticate to AWS. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+    aws_secret_access_key: Optional[str] = SecretField(
+        default=None,
+        description="The AWS secret access key to use to authenticate to AWS. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+    aws_profile: Optional[str] = Field(
+        None,
+        description="The AWS profile to use for authentication if not using "
+        "service connectors or explicit credentials. If not provided, the "
+        "default profile will be used.",
+    )
+    aws_auth_role_arn: Optional[str] = Field(
+        None,
+        description="The ARN of an intermediate IAM role to assume when "
+        "authenticating to AWS.",
+    )
+    region: Optional[str] = Field(
+        None,
+        description="The AWS region where the AWS Batch job will be run. "
+        "If not provided, the value from the default AWS config will be used.",
+    )
+
+    @property
+    def is_remote(self) -> bool:
+        """Checks if this stack component is running remotely.
+
+        This designation is used to determine if the stack component can be
+        used with a local ZenML database or if it requires a remote ZenML
+        server.
+
+        Returns:
+            True if this config is for a remote component, False otherwise.
+        """
+        return True
+
+
+class AWSBatchStepOperatorFlavor(BaseStepOperatorFlavor):
+    """Flavor for the AWS Batch step operator."""
+
+    @property
+    def name(self) -> str:
+        """Name of the flavor.
+
+        Returns:
+            The name of the flavor.
+        """
+        return AWS_BATCH_STEP_OPERATOR_FLAVOR
+
+    @property
+    def service_connector_requirements(
+        self,
+    ) -> Optional[ServiceConnectorRequirements]:
+        """Service connector resource requirements for service connectors.
+
+        Specifies resource requirements that are used to filter the available
+        service connector types that are compatible with this flavor.
+
+        Returns:
+            Requirements for compatible service connectors, if a service
+            connector is required for this flavor.
+        """
+        return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE)
+
+    @property
+    def docs_url(self) -> Optional[str]:
+        """A url to point at docs explaining this flavor.
+
+        Returns:
+            A flavor docs url.
+        """
+        return self.generate_default_docs_url()
+
+    @property
+    def sdk_docs_url(self) -> Optional[str]:
+        """A url to point at SDK docs explaining this flavor.
+
+        Returns:
+            A flavor SDK docs url.
+        """
+        return self.generate_default_sdk_docs_url()
+
+    @property
+    def logo_url(self) -> str:
+        """A url to represent the flavor in the dashboard.
+
+        Returns:
+            The flavor logo.
+        """
+        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/aws_batch.png"
+
+    @property
+    def config_class(self) -> Type[AWSBatchStepOperatorConfig]:
+        """Returns the `AWSBatchStepOperatorConfig` config class.
+
+        Returns:
+            The config class.
+        """
+        return AWSBatchStepOperatorConfig
+
+    @property
+    def implementation_class(self) -> Type["AWSBatchStepOperator"]:
+        """Implementation class.
+
+        Returns:
+            The implementation class.
+        """
+        from zenml.integrations.aws.step_operators import AWSBatchStepOperator
+
+        return AWSBatchStepOperator
diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py
index 9eee3140d43..f766a309217 100644
--- a/src/zenml/integrations/aws/step_operators/__init__.py
+++ b/src/zenml/integrations/aws/step_operators/__init__.py
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
-"""Initialization of the Sagemaker Step Operator."""
+"""Initialization of the AWS step operators."""
 
-from zenml.integrations.aws.step_operators.sagemaker_step_operator import (  # noqa
+from zenml.integrations.aws.step_operators.sagemaker_step_operator import (  # noqa: F401
     SagemakerStepOperator,
 )
-
-__all__ = ["SagemakerStepOperator"]
+from zenml.integrations.aws.step_operators.aws_batch_step_operator import (  # noqa: F401
+    AWSBatchStepOperator,
+)
+
+__all__ = ["SagemakerStepOperator", "AWSBatchStepOperator"]
diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py
new file mode 100644
index 00000000000..0eac0fe8b06
--- /dev/null
+++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py
@@ -0,0 +1,510 @@
+# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""Implementation of the AWS Batch Step Operator."""
+
+import math
+import time
+from string import ascii_letters, digits
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    cast,
+)
+
+from boto3 import Session
+from botocore.exceptions import ClientError
+from pydantic import BaseModel, PositiveInt, field_validator
+
+from zenml.config.build_configuration import BuildConfiguration
+from zenml.enums import StackComponentType
+from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import (
+    AWSBatchStepOperatorConfig,
+    AWSBatchStepOperatorSettings,
+)
+from zenml.logger import get_logger
+from zenml.stack import Stack, StackValidator
+from zenml.step_operators import BaseStepOperator
+from zenml.utils.string_utils import random_str
+
+if TYPE_CHECKING:
+    from zenml.config import ResourceSettings
+    from zenml.config.base_settings import BaseSettings
+    from zenml.config.step_run_info import StepRunInfo
+    from zenml.models import PipelineDeploymentBase
+
+logger = get_logger(__name__)
+
+BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator"
+_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT"
+
+
+VALID_FARGATE_VCPU = ("0.25", "0.5", "1", "2", "4", "8", "16")
+VALID_FARGATE_MEMORY = {
+    "0.25": ("512", "1024", "2048"),
+    "0.5": ("1024", "2048", "3072", "4096"),
+    "1": ("2048", "3072", "4096", "5120", "6144", "7168", "8192"),
+    "2": tuple(str(m) for m in range(4096, 16385, 1024)),
+    "4": tuple(str(m) for m in range(8192, 30721, 1024)),
+    "8": tuple(str(m) for m in range(16384, 61441, 4096)),
+    "16": tuple(str(m) for m in range(32768, 122881, 8192)),
+}
+
+
+class ResourceRequirement(BaseModel):
+    """A single AWS Batch resource requirement entry."""
+
+    type: Literal["MEMORY", "VCPU", "GPU"]
+    value: str
+
+
+class AWSBatchJobDefinitionContainerProperties(BaseModel):
+    """An AWS Batch job subconfiguration model for a container type job's
+    container specification."""
+
+    image: str
+    command: List[str]
+    jobRoleArn: str
+    executionRoleArn: str
+    environment: List[Dict[str, str]] = []  # keys: 'name', 'value'
+    resourceRequirements: List[ResourceRequirement] = []
+    secrets: List[Dict[str, str]] = []  # keys: 'name', 'valueFrom'
+
+
+class AWSBatchJobDefinitionEC2ContainerProperties(
+    AWSBatchJobDefinitionContainerProperties
+):
+    logConfiguration: Dict[
+        Literal["logDriver"],
+        Literal[
+            "awsfirelens", "awslogs", "fluentd", "gelf", "json-file",
+            "journald", "logentries", "syslog", "splunk",
+        ],
+    ] = {"logDriver": "awslogs"}
+
+    @field_validator("resourceRequirements")
+    def check_resource_requirements(
+        cls, resource_requirements: List[ResourceRequirement]
+    ) -> List[ResourceRequirement]:
+        gpu_requirement = [
+            req for req in resource_requirements if req.type == "GPU"
+        ]
+        cpu_requirement = [
+            req for req in resource_requirements if req.type == "VCPU"
+        ][0]
+        memory_requirement = [
+            req for req in resource_requirements if req.type == "MEMORY"
+        ][0]
+
+        cpu_float = float(cpu_requirement.value)
+        cpu_rounded_int = math.ceil(cpu_float)
+
+        if cpu_float != cpu_rounded_int:
+            logger.info(
+                f"Rounded fractional EC2 resource VCPU value from "
+                f"{cpu_float} to {cpu_rounded_int} since AWS Batch on EC2 "
+                "requires a whole integer VCPU count value."
+            )
+        resource_requirements = [
+            ResourceRequirement(type="VCPU", value=str(cpu_rounded_int)),
+            memory_requirement,
+        ]
+        resource_requirements.extend(gpu_requirement)
+
+        return resource_requirements
+
+
+class AWSBatchJobDefinitionFargateContainerProperties(
+    AWSBatchJobDefinitionContainerProperties
+):
+    logConfiguration: Dict[
+        Literal["logDriver"], Literal["awslogs", "splunk"]
+    ] = {"logDriver": "awslogs"}
+    networkConfiguration: Dict[
+        Literal["assignPublicIp"], Literal["ENABLED", "DISABLED"]
+    ] = {"assignPublicIp": "ENABLED"}
+
+    @field_validator("resourceRequirements")
+    def check_resource_requirements(
+        cls, resource_requirements: List[ResourceRequirement]
+    ) -> List[ResourceRequirement]:
+        gpu_requirement = [
+            req for req in resource_requirements if req.type == "GPU"
+        ]
+
+        if gpu_requirement:
+            raise ValueError(
+                "Invalid fargate resource requirement: GPU. Use EC2 "
+                "platform capability if you need custom devices."
+            )
+
+        cpu_requirement = [
+            req for req in resource_requirements if req.type == "VCPU"
+        ][0]
+        memory_requirement = [
+            req for req in resource_requirements if req.type == "MEMORY"
+        ][0]
+
+        if cpu_requirement.value not in VALID_FARGATE_VCPU:
+            raise ValueError(
+                f"Invalid fargate resource requirement VCPU value "
+                f"{cpu_requirement.value}. Must be one of "
+                f"{VALID_FARGATE_VCPU}."
+            )
+
+        memory_options = VALID_FARGATE_MEMORY[cpu_requirement.value]
+        if memory_requirement.value not in memory_options:
+            raise ValueError(
+                f"Invalid fargate resource requirement MEMORY value "
+                f"{memory_requirement.value}. For VCPU="
+                f"{cpu_requirement.value}, MEMORY must be one of "
+                f"{memory_options}."
+            )
+
+        return resource_requirements
+
+
+class AWSBatchJobDefinitionRetryStrategy(BaseModel):
+    """An AWS Batch job subconfiguration model for retry specifications."""
+
+    attempts: PositiveInt = 2
+    evaluateOnExit: List[Dict[str, str]] = [
+        {
+            "onExitCode": "137",  # out-of-memory killed
+            "action": "RETRY",
+        },
+        {
+            "onReason": "Host EC2 terminated",  # host instance reclaimed -> retry
+            "action": "RETRY",
+        },
+    ]
+
+
+class AWSBatchJobDefinition(BaseModel):
+    """A utility to validate AWS Batch job definitions. Base class for
+    container and multi-node job definition types."""
+
+    jobDefinitionName: str
+    type: str = "container"
+    parameters: Dict[str, str] = {}
+    # schedulingPriority: int = 0  # ignored in FIFO queues
+    retryStrategy: AWSBatchJobDefinitionRetryStrategy = (
+        AWSBatchJobDefinitionRetryStrategy()
+    )
+    propagateTags: bool = False
+    timeout: Dict[str, int] = {"attemptDurationSeconds": 3600}
+    tags: Dict[str, str] = {}
+    platformCapabilities: List[Literal["EC2", "FARGATE"]]
+
+
+class AWSBatchJobEC2Definition(AWSBatchJobDefinition):
+    containerProperties: AWSBatchJobDefinitionEC2ContainerProperties
+    platformCapabilities: List[Literal["EC2"]] = ["EC2"]
+
+
+class AWSBatchJobFargateDefinition(AWSBatchJobDefinition):
+    containerProperties: AWSBatchJobDefinitionFargateContainerProperties
+    platformCapabilities: List[Literal["FARGATE"]] = ["FARGATE"]
+
+
+class AWSBatchStepOperator(BaseStepOperator):
+    """Step operator to run a step on AWS Batch.
+
+    This class defines code that builds an image with the ZenML entrypoint
+    to run using AWS Batch.
+    """
+
+    @property
+    def config(self) -> AWSBatchStepOperatorConfig:
+        """Returns the `AWSBatchStepOperatorConfig` config.
+
+        Returns:
+            The configuration.
+        """
+        return cast(AWSBatchStepOperatorConfig, self._config)
+
+    @property
+    def settings_class(self) -> Optional[Type["BaseSettings"]]:
+        """Settings class for the AWS Batch step operator.
+
+        Returns:
+            The settings class.
+        """
+        return AWSBatchStepOperatorSettings
+
+    def _get_aws_session(self) -> Session:
+        """Method to create a boto3 session with proper authentication.
+
+        Returns:
+            The boto3 session.
+
+        Raises:
+            RuntimeError: If the connector returns the wrong type for the
+                session.
+        """
+        # Get authenticated session
+        # Option 1: Service connector
+        boto_session: Session
+        if connector := self.get_connector():
+            boto_session = connector.connect()
+            if not isinstance(boto_session, Session):
+                raise RuntimeError(
+                    f"Expected to receive a `boto3.Session` object from the "
+                    f"linked connector, but got type `{type(boto_session)}`."
+                )
+        # Option 2: Explicit configuration
+        # Args that are not provided will be taken from the default AWS
+        # config.
+        else:
+            boto_session = Session(
+                aws_access_key_id=self.config.aws_access_key_id,
+                aws_secret_access_key=self.config.aws_secret_access_key,
+                region_name=self.config.region,
+                profile_name=self.config.aws_profile,
+            )
+            # If a role ARN is provided for authentication, assume the role
+            if self.config.aws_auth_role_arn:
+                sts = boto_session.client("sts")
+                response = sts.assume_role(
+                    RoleArn=self.config.aws_auth_role_arn,
+                    RoleSessionName="zenml-aws-batch-step-operator",
+                )
+                credentials = response["Credentials"]
+                boto_session = Session(
+                    aws_access_key_id=credentials["AccessKeyId"],
+                    aws_secret_access_key=credentials["SecretAccessKey"],
+                    aws_session_token=credentials["SessionToken"],
+                    region_name=self.config.region,
+                )
+        return boto_session
+
+    @property
+    def validator(self) -> Optional[StackValidator]:
+        """Validates the stack.
+
+        Returns:
+            A validator that checks that the stack contains a remote container
+            registry and a remote artifact store.
+        """
+
+        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
+            if stack.artifact_store.config.is_local:
+                return False, (
+                    "The Batch step operator runs code remotely and "
+                    "needs to write files into the artifact store, but the "
+                    f"artifact store `{stack.artifact_store.name}` of the "
+                    "active stack is local. Please ensure that your stack "
+                    "contains a remote artifact store when using the Batch "
+                    "step operator."
+                )
+
+            container_registry = stack.container_registry
+            assert container_registry is not None
+
+            if container_registry.config.is_local:
+                return False, (
+                    "The Batch step operator runs code remotely and "
+                    "needs to push/pull Docker images, but the "
+                    f"container registry `{container_registry.name}` of the "
+                    "active stack is local. Please ensure that your stack "
+                    "contains a remote container registry when using the "
+                    "Batch step operator."
+                )
+
+            return True, ""
+
+        return StackValidator(
+            required_components={
+                StackComponentType.CONTAINER_REGISTRY,
+                StackComponentType.IMAGE_BUILDER,
+            },
+            custom_validation_function=_validate_remote_components,
+        )
+
+    @staticmethod
+    def map_environment(environment: Dict[str, str]) -> List[Dict[str, str]]:
+        """Utility to map the {name: value} environment to the
+        [{"name": name, "value": value}, ...] convention used in the AWS
+        Batch job definition spec.
+
+        Args:
+            environment: The step's environment variable specification.
+
+        Returns:
+            The mapped environment variable specification.
+        """
+        return [{"name": k, "value": v} for k, v in environment.items()]
+
+    @staticmethod
+    def map_resource_settings(
+        resource_settings: "ResourceSettings",
+    ) -> List[ResourceRequirement]:
+        """Utility to map the resource_settings to the resource convention
+        used in the AWS Batch job definition spec.
+
+        Args:
+            resource_settings: The step's resource settings.
+
+        Returns:
+            The mapped resource settings.
+        """
+        mapped_resource_settings = []
+
+        # handle cpu requirements; format with :g so whole floats render as
+        # integers ("1.0" -> "1"), matching the values AWS Batch expects
+        if resource_settings.cpu_count is not None:
+            cpu_requirement = ResourceRequirement(
+                value=f"{resource_settings.cpu_count:g}", type="VCPU"
+            )
+        else:
+            cpu_requirement = ResourceRequirement(value="1", type="VCPU")
+
+        mapped_resource_settings.append(cpu_requirement)
+
+        # handle memory requirements
+        memory = resource_settings.get_memory(unit="MiB")
+        if memory:
+            memory_requirement = ResourceRequirement(
+                value=str(int(memory)), type="MEMORY"
+            )
+        else:
+            memory_requirement = ResourceRequirement(
+                value="1024", type="MEMORY"
+            )
+        mapped_resource_settings.append(memory_requirement)
+
+        # handle gpu requirements
+        if resource_settings.gpu_count:
+            mapped_resource_settings.append(
+                ResourceRequirement(
+                    value=str(resource_settings.gpu_count), type="GPU"
+                )
+            )
+
+        return mapped_resource_settings
+
+    @staticmethod
+    def sanitize_name(name: str) -> str:
+        """Utility to replace invalid characters in a name with hyphens.
+
+        Args:
+            name: The name to sanitize.
+
+        Returns:
+            The sanitized name, containing only alphanumeric characters,
+            hyphens and underscores.
+        """
+        valid_characters = ascii_letters + digits + "-_"
+        sanitized_name = ""
+        for char in name:
+            sanitized_name += char if char in valid_characters else "-"
+
+        return sanitized_name
+
+    def generate_unique_batch_job_name(self, info: "StepRunInfo") -> str:
+        """Utility to generate a unique AWS Batch job name.
+
+        Args:
+            info: The step run information.
+
+        Returns:
+            A unique name for the step's AWS Batch job definition.
+        """
+        # Batch allows 128 alphanumeric characters at maximum for job names.
+        # We sanitize the pipeline and step names before concatenating,
+        # capping at 115 characters and finally suffixing with a 6 character
+        # random string.
+        sanitized_pipeline_name = self.sanitize_name(info.pipeline.name)
+        sanitized_step_name = self.sanitize_name(info.pipeline_step_name)
+
+        job_name = f"{sanitized_pipeline_name}-{sanitized_step_name}"[:115]
+        suffix = random_str(6)
+        return f"{job_name}-{suffix}"
+
+    def generate_job_definition(
+        self,
+        info: "StepRunInfo",
+        entrypoint_command: List[str],
+        environment: Dict[str, str],
+    ) -> AWSBatchJobDefinition:
+        """Utility to map ZenML internal configurations to a valid AWS Batch
+        job definition.
+
+        Args:
+            info: The step run information.
+            entrypoint_command: The command to execute in the job container.
+            environment: Environment variables to set in the job container.
+
+        Returns:
+            A validated AWS Batch job definition.
+        """
+        image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY)
+
+        resource_settings = info.config.resource_settings
+        step_settings = cast(
+            AWSBatchStepOperatorSettings, self.get_settings(info)
+        )
+
+        if step_settings.environment:
+            environment.update(step_settings.environment)
+
+        job_name = self.generate_unique_batch_job_name(info)
+
+        if step_settings.backend == "EC2":
+            job_definition_class = AWSBatchJobEC2Definition
+            container_properties_class = (
+                AWSBatchJobDefinitionEC2ContainerProperties
+            )
+            container_kwargs = {}
+        else:
+            job_definition_class = AWSBatchJobFargateDefinition
+            container_properties_class = (
+                AWSBatchJobDefinitionFargateContainerProperties
+            )
+            container_kwargs = {
+                "networkConfiguration": {
+                    "assignPublicIp": step_settings.assign_public_ip
+                }
+            }
+
+        return job_definition_class(
+            jobDefinitionName=job_name,
+            timeout={"attemptDurationSeconds": step_settings.timeout_seconds},
+            type="container",
+            containerProperties=container_properties_class(
+                executionRoleArn=self.config.execution_role,
+                jobRoleArn=self.config.job_role,
+                image=image_name,
+                command=entrypoint_command,
+                environment=self.map_environment(environment),
+                resourceRequirements=self.map_resource_settings(
+                    resource_settings
+                ),
+                **container_kwargs,
+            ),
+        )
+
+    def get_docker_builds(
+        self, deployment: "PipelineDeploymentBase"
+    ) -> List["BuildConfiguration"]:
+        """Gets the Docker builds required for the component.
+
+        Args:
+            deployment: The pipeline deployment for which to get the builds.
+
+        Returns:
+            The required Docker builds.
+        """
+        builds = []
+        for step_name, step in deployment.step_configurations.items():
+            if step.config.uses_step_operator(self.name):
+                build = BuildConfiguration(
+                    key=BATCH_DOCKER_IMAGE_KEY,
+                    settings=step.config.docker_settings,
+                    step_name=step_name,
+                    entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
+                )
+                builds.append(build)
+
+        return builds
+
+    def launch(
+        self,
+        info: "StepRunInfo",
+        entrypoint_command: List[str],
+        environment: Dict[str, str],
+    ) -> None:
+        """Launches a step on AWS Batch.
+
+        Args:
+            info: Information about the step run.
+            entrypoint_command: Command that executes the step.
+            environment: Environment variables to set in the step operator
+                environment.
+
+        Raises:
+            RuntimeError: If the AWS Batch job fails, or if the connector
+                returns an object that is not a `boto3.Session`.
+        """
+        job_definition = self.generate_job_definition(
+            info, entrypoint_command, environment
+        )
+
+        logger.info(f"Job definition: {job_definition}")
+
+        boto_session = self._get_aws_session()
+        batch_client = boto_session.client("batch")
+
+        response = batch_client.register_job_definition(
+            **job_definition.model_dump()
+        )
+
+        job_definition_name = response["jobDefinitionName"]
+
+        step_settings = cast(
+            AWSBatchStepOperatorSettings, self.get_settings(info)
+        )
+
+        response = batch_client.submit_job(
+            jobName=job_definition.jobDefinitionName,
+            jobQueue=step_settings.job_queue_name
+            or self.config.default_job_queue_name,
+            jobDefinition=job_definition_name,
+        )
+
+        job_id = response["jobId"]
+
+        while True:
+            try:
+                response = batch_client.describe_jobs(jobs=[job_id])
+                status = response["jobs"][0]["status"]
+                status_reason = response["jobs"][0].get(
+                    "statusReason", "Unknown"
+                )
+
+                if status == "SUCCEEDED":
+                    logger.info(f"Job completed successfully: {job_id}")
+                    break
+                elif status == "FAILED":
+                    raise RuntimeError(
+                        f"Job {job_id} failed: {status_reason}"
+                    )
+                else:
+                    logger.info(
+                        f"Job {job_id} neither failed nor succeeded. Status: "
+                        f"{status}. Status reason: {status_reason}. Waiting "
+                        "another 10 seconds."
+                    )
+                    time.sleep(10)
+            except ClientError as e:
+                logger.error(f"Failed to describe job {job_id}: {e}")
+                raise
diff --git a/tests/integration/integrations/aws/step_operators/__init__.py b/tests/integration/integrations/aws/step_operators/__init__.py
new file mode 100644
index 00000000000..cd90a82cfc2
--- /dev/null
+++ b/tests/integration/integrations/aws/step_operators/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py
new file mode 100644
index 00000000000..2f0bac6e3e4
--- /dev/null
+++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py
@@ -0,0 +1,241 @@
+# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+import pytest
+from pydantic import ValidationError
+
+from zenml.config.resource_settings import ResourceSettings
+from zenml.integrations.aws.step_operators.aws_batch_step_operator import (
+    AWSBatchJobDefinitionEC2ContainerProperties,
+    AWSBatchJobDefinitionFargateContainerProperties,
+    AWSBatchJobEC2Definition,
+    AWSBatchJobFargateDefinition,
+    AWSBatchStepOperator,
+    ResourceRequirement,
+    VALID_FARGATE_MEMORY,
+    VALID_FARGATE_VCPU,
+)
+
+
+def test_aws_batch_step_operator_map_environment():
+    """Tests the AWSBatchStepOperator's map_environment static method."""
+    test_environment = {"key_1": "value_1", "key_2": "value_2"}
+    expected = [
+        {"name": "key_1", "value": "value_1"},
+        {"name": "key_2", "value": "value_2"},
+    ]
+
+    assert AWSBatchStepOperator.map_environment(test_environment) == expected
+
+
+@pytest.mark.parametrize(
+    "test_resource_settings,expected",
+    [
+        (
+            ResourceSettings(),
+            [
+                ResourceRequirement(value="1", type="VCPU"),
+                ResourceRequirement(value="1024", type="MEMORY"),
+            ],
+        ),
+        (
+            ResourceSettings(cpu_count=0.25, gpu_count=1, memory="10MiB"),
+            [
+                ResourceRequirement(value="0.25", type="VCPU"),
+                ResourceRequirement(value="10", type="MEMORY"),
+                ResourceRequirement(value="1", type="GPU"),
+            ],
+        ),
+    ],
+)
+def test_aws_batch_step_operator_map_resource_settings(
+    test_resource_settings, expected
+):
+    """Tests the AWSBatchStepOperator's map_resource_settings static method."""
+    assert (
+        AWSBatchStepOperator.map_resource_settings(test_resource_settings)
+        == expected
+    )
+
+
+@pytest.mark.parametrize(
+    "test_name,expected",
+    [
+        ("valid-name-123abcABC_", "valid-name-123abcABC_"),
+        ('this!is@not"a£valid$name%123', "this-is-not-a-valid-name-123"),
+    ],
+)
+def test_aws_batch_step_operator_sanitize_name(test_name, expected):
+    """Tests the AWSBatchStepOperator's sanitize_name static method."""
+    assert AWSBatchStepOperator.sanitize_name(test_name) == expected
+
+
+@pytest.mark.parametrize(
+    "test_requirements,expected",
+    [
+        (
+            [
+                ResourceRequirement(value="0.4", type="VCPU"),
+                ResourceRequirement(value="100", type="MEMORY"),
+                ResourceRequirement(value="1", type="GPU"),
+            ],
+            [
+                ResourceRequirement(value="1", type="VCPU"),
+                ResourceRequirement(value="100", type="MEMORY"),
+                ResourceRequirement(value="1", type="GPU"),
+            ],
+        ),
+        (
+            [
+                ResourceRequirement(value="1.1", type="VCPU"),
+                ResourceRequirement(value="100", type="MEMORY"),
+            ],
+            [
+                ResourceRequirement(value="2", type="VCPU"),
+                ResourceRequirement(value="100", type="MEMORY"),
+            ],
+        ),
+    ],
+)
+def test_aws_batch_job_definition_ec2_container_properties_resource_validation(
+    test_requirements, expected
+):
+    actual = AWSBatchJobDefinitionEC2ContainerProperties(
+        image="test-image",
+        command=["test", "command"],
+        jobRoleArn="test-job-role-arn",
+        executionRoleArn="test-execution-role-arn",
+        resourceRequirements=test_requirements,
+    )
+
+    assert actual.resourceRequirements == expected
+
+
+@pytest.mark.parametrize(
+    "test_vcpu_memory_indices",
+    [
+        (i, j)
+        for i in range(len(VALID_FARGATE_VCPU))
+        for j in range(len(VALID_FARGATE_MEMORY[VALID_FARGATE_VCPU[i]]))
+    ],
+)
+def test_aws_batch_job_definition_fargate_container_properties(
+    test_vcpu_memory_indices,
+):
+    vcpu_index, memory_index = test_vcpu_memory_indices
+    test_vcpu_value = VALID_FARGATE_VCPU[vcpu_index]
+    test_memory_value = VALID_FARGATE_MEMORY[test_vcpu_value][memory_index]
+
+    test_valid_requirements = [
+        ResourceRequirement(type="VCPU", value=test_vcpu_value),
+        ResourceRequirement(type="MEMORY", value=test_memory_value),
+    ]
+
+    AWSBatchJobDefinitionFargateContainerProperties(
+        image="test-image",
+        command=["test", "command"],
+        jobRoleArn="test-job-role-arn",
+        executionRoleArn="test-execution-role-arn",
+        resourceRequirements=test_valid_requirements,
+    )
+
+
+@pytest.mark.parametrize(
+    "test_invalid_requirements,expected_message",
+    [
+        (
+            [
+                ResourceRequirement(type="VCPU", value="invalid-value"),
+                ResourceRequirement(type="MEMORY", value="irrelevant-value"),
+            ],
+            "Invalid fargate resource requirement VCPU value*",
+        ),
+        (
+            [
+                ResourceRequirement(type="VCPU", value="16"),  # valid
+                ResourceRequirement(type="MEMORY", value="invalid-value"),
+            ],
+            "Invalid fargate resource requirement MEMORY value*",
+        ),
+        (
+            [
+                ResourceRequirement(type="VCPU", value="irrelevant-value"),
+                ResourceRequirement(type="MEMORY", value="irrelevant-value"),
+                ResourceRequirement(type="GPU", value="1"),  # invalid
+            ],
+            "Invalid fargate resource requirement: GPU. Use EC2*",
+        ),
+    ],
+)
+def test_aws_batch_job_definition_fargate_container_properties_raise_invalid_requirements(
+    test_invalid_requirements, expected_message
+):
+    with pytest.raises(ValidationError, match=expected_message):
+        AWSBatchJobDefinitionFargateContainerProperties(
+            image="test-image",
+            command=["test", "command"],
+            jobRoleArn="test-job-role-arn",
+            executionRoleArn="test-execution-role-arn",
+            resourceRequirements=test_invalid_requirements,
+        )
+
+
+def test_aws_batch_job_ec2_definition():
+    AWSBatchJobEC2Definition(
+        jobDefinitionName="test",
+        containerProperties=AWSBatchJobDefinitionEC2ContainerProperties(
+            image="test-image",
+            command=["test", "command"],
+            jobRoleArn="test-job-role-arn",
+            executionRoleArn="test-execution-role-arn",
+            resourceRequirements=[
+                ResourceRequirement(value="1", type="GPU"),
+                ResourceRequirement(value="1", type="VCPU"),
+                ResourceRequirement(value="1024", type="MEMORY"),
+            ],
+        ),
+    )
+
+
+def test_aws_batch_job_fargate_definition():
+    AWSBatchJobFargateDefinition(
+        jobDefinitionName="test",
+        containerProperties=AWSBatchJobDefinitionFargateContainerProperties(
+            image="test-image",
+            command=["test", "command"],
+            jobRoleArn="test-job-role-arn",
+            executionRoleArn="test-execution-role-arn",
+            resourceRequirements=[
+                ResourceRequirement(value="0.5", type="VCPU"),
+                ResourceRequirement(value="3072", type="MEMORY"),
+            ],
+        ),
+    )
diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py
new file mode 100644
index 00000000000..6554d73b57e
--- /dev/null
+++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py
@@ -0,0 +1,8 @@
+from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings
+
+def test_aws_batch_step_operator_settings():
+    AWSBatchStepOperatorSettings(
+        job_queue_name="test-job-queue",
+        environment={"key_1": "value_1", "key_2": "value_2"},
+        timeout_seconds=3600,
+    )
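Usage note (reviewer addition, not part of the patch): after registering the component, e.g. with `zenml step-operator register aws_batch --flavor=aws_batch --execution_role=<ecs-execution-role-arn> --job_role=<ecs-job-role-arn> --default_job_queue_name=<queue-name>` and adding it to the active stack (`zenml stack update -s aws_batch`), a step opts in via ZenML's standard settings mechanism. A minimal sketch, assuming the component is named `aws_batch`; the queue name and resource values are illustrative and must satisfy the Fargate VCPU/MEMORY pairs validated above:

    from zenml import step
    from zenml.config import ResourceSettings
    from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import (
        AWSBatchStepOperatorSettings,
    )


    @step(
        step_operator="aws_batch",  # name chosen at registration time
        settings={
            # the key can also be scoped as "step_operator.aws_batch"
            "step_operator": AWSBatchStepOperatorSettings(
                backend="FARGATE",
                job_queue_name="zenml-fargate-queue",  # illustrative queue
                timeout_seconds=1800,
                environment={"LOG_LEVEL": "INFO"},
            ),
            # mapped onto resourceRequirements by map_resource_settings()
            "resources": ResourceSettings(cpu_count=0.5, memory="1024MiB"),
        },
    )
    def train() -> None:
        ...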
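For reviewers who want to see the exact payload `launch()` registers: a small sketch, using only classes from this diff, that builds a Fargate job definition the same way `generate_job_definition()` does and dumps it. The ARNs, image, and command are placeholders:

    from zenml.integrations.aws.step_operators.aws_batch_step_operator import (
        AWSBatchJobDefinitionFargateContainerProperties,
        AWSBatchJobFargateDefinition,
        ResourceRequirement,
    )

    job_definition = AWSBatchJobFargateDefinition(
        jobDefinitionName="example-pipeline-example-step-a1b2c3",  # placeholder
        timeout={"attemptDurationSeconds": 1800},
        containerProperties=AWSBatchJobDefinitionFargateContainerProperties(
            image="123456789012.dkr.ecr.eu-central-1.amazonaws.com/zenml:example",
            command=["python", "-m", "zenml.entrypoint"],  # placeholder command
            jobRoleArn="arn:aws:iam::123456789012:role/ecs-job-role",
            executionRoleArn="arn:aws:iam::123456789012:role/ecs-execution-role",
            environment=[{"name": "LOG_LEVEL", "value": "INFO"}],
            resourceRequirements=[
                ResourceRequirement(type="VCPU", value="0.5"),
                ResourceRequirement(type="MEMORY", value="1024"),
            ],
        ),
    )

    # This dict is what launch() passes to
    # batch_client.register_job_definition(**job_definition.model_dump()).
    print(job_definition.model_dump())

    # Swapping in ResourceRequirement(type="GPU", value="1") raises a pydantic
    # ValidationError ("Invalid fargate resource requirement: GPU. ..."), as
    # exercised by the tests above.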