From ce1de79bceb1dab2634d482fd26414ae4ff2e52d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:08:34 +0200 Subject: [PATCH 01/36] Add version 0.84.3 to legacy docs (#3949) Co-authored-by: ZenML GmbH (cherry picked from commit af530778c1ce819a4111a0178b4ee778919265cc) --- docs/book/reference/legacy-docs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/book/reference/legacy-docs.md b/docs/book/reference/legacy-docs.md index 668b54d7ba3..06d2c2cd576 100644 --- a/docs/book/reference/legacy-docs.md +++ b/docs/book/reference/legacy-docs.md @@ -5,4 +5,4 @@ icon: user-bounty-hunter # Legacy docs -
0.84.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.2/
0.84.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.1/
0.84.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.0/
0.83.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.1/
0.83.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.0/
0.82.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.82.0/
0.81.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.81.0/
0.80.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.2/
0.80.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.1/
0.80.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.0/
0.75.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.75.0/
0.74.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.74.0/
0.73.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.73.0/
0.72.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.72.0
0.71.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.71.0
0.70.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.70.0
0.68.1https://zenml-io.gitbook.io/zenml-legacy-documentation/0.68.1
0.68.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.68.0/
0.67.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/IAT17KUXNgudlhTyCPn4/
0.66.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.66.0
0.65.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.65.0/
0.64.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.64.0/
0.63.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/N7uiChp9LXuqoNrLj3U3/
0.62.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.62.0/
0.61.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.61.0
0.60.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.60.0/
0.58.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.2/
0.58.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.1/
0.58.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.0
0.57.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.1
0.57.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.0
0.56.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.4
0.56.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.3
0.56.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.2
0.56.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.1
0.55.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.5
0.55.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.4
0.55.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.3
0.55.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.2
0.55.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.1
0.55.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.0
0.54.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.1
0.54.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.0
0.53.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.1
0.53.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.0
0.52.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.52.0
0.51.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.51.0
0.50.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.50.0
0.47.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.47.0-legacy
0.46.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.1-legacy
0.46.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.0-legacy
0.45.6https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.6-legacy
0.45.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.5-legacy
0.45.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.4-legacy
0.45.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.3-legacy
0.45.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.2-legacy
0.44.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.4-legacy
0.44.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.3-legacy
0.44.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.2-legacy
0.44.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.1-legacy
0.43.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.1-legacy
0.43.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.0-legacy
0.42.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.2-legacy
0.42.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.1-legacy
0.42.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.0-legacy
0.41.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.41.0-legacy
0.40.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.3-legacy
0.40.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.2-legacy
0.40.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.1-legacy
0.40.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.0-legacy
0.39.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.1-legacy
0.39.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.0-legacy
0.38.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.38.0-legacy
0.37.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.37.0-legacy
0.36.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.1-legacy
0.36.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.0-legacy
0.35.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.1-legacy
0.35.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.0-legacy
0.34.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.34.0-legacy
0.33.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.33.0-legacy
0.32.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.1-legacy
0.32.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.0-legacy
0.31.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.1-legacy
0.31.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.0-legacy
0.30.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.30.0-legacy
0.20.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.5-legacy
0.20.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.4-legacy
0.20.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.3-legacy
0.20.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.2-legacy
0.20.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.0-legacy
0.13.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.2
0.13.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.1
0.13.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.0
0.12.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.12.0
0.11.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.11.0
0.10.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.10.0
+
0.84.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.3/
0.84.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.2/
0.84.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.1/
0.84.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.84.0/
0.83.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.1/
0.83.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.83.0/
0.82.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.82.0/
0.81.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.81.0/
0.80.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.2/
0.80.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.1/
0.80.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.80.0/
0.75.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.75.0/
0.74.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.74.0/
0.73.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.73.0/
0.72.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.72.0
0.71.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.71.0
0.70.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.70.0
0.68.1https://zenml-io.gitbook.io/zenml-legacy-documentation/0.68.1
0.68.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.68.0/
0.67.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/IAT17KUXNgudlhTyCPn4/
0.66.0https://zenml-io.gitbook.io/zenml-legacy-documentation/0.66.0
0.65.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.65.0/
0.64.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.64.0/
0.63.0https://app.gitbook.com/o/-MCl1Hlw9oU4xibZ-ymf/s/N7uiChp9LXuqoNrLj3U3/
0.62.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.62.0/
0.61.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.61.0
0.60.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.60.0/
0.58.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.2/
0.58.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.1/
0.58.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.58.0
0.57.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.1
0.57.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.57.0
0.56.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.4
0.56.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.3
0.56.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.2
0.56.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.56.1
0.55.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.5
0.55.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.4
0.55.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.3
0.55.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.2
0.55.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.1
0.55.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.55.0
0.54.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.1
0.54.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.54.0
0.53.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.1
0.53.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.53.0
0.52.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.52.0
0.51.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.51.0
0.50.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.50.0
0.47.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.47.0-legacy
0.46.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.1-legacy
0.46.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.46.0-legacy
0.45.6https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.6-legacy
0.45.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.5-legacy
0.45.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.4-legacy
0.45.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.3-legacy
0.45.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.45.2-legacy
0.44.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.4-legacy
0.44.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.3-legacy
0.44.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.2-legacy
0.44.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.44.1-legacy
0.43.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.1-legacy
0.43.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.43.0-legacy
0.42.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.2-legacy
0.42.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.1-legacy
0.42.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.42.0-legacy
0.41.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.41.0-legacy
0.40.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.3-legacy
0.40.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.2-legacy
0.40.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.1-legacy
0.40.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.40.0-legacy
0.39.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.1-legacy
0.39.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.39.0-legacy
0.38.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.38.0-legacy
0.37.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.37.0-legacy
0.36.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.1-legacy
0.36.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.36.0-legacy
0.35.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.1-legacy
0.35.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.35.0-legacy
0.34.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.34.0-legacy
0.33.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.33.0-legacy
0.32.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.1-legacy
0.32.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.32.0-legacy
0.31.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.1-legacy
0.31.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.31.0-legacy
0.30.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.30.0-legacy
0.20.5https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.5-legacy
0.20.4https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.4-legacy
0.20.3https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.3-legacy
0.20.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.2-legacy
0.20.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.20.0-legacy
0.13.2https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.2
0.13.1https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.1
0.13.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.13.0
0.12.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.12.0
0.11.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.11.0
0.10.0https://zenml-io.gitbook.io/zenml-legacy-documentation/v/0.10.0
From e87d336343a5a01b66f44013edf3d28d998ce515 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 00:40:49 +0100 Subject: [PATCH 02/36] started creating required files and mapping out the zenml config -> aws batch job definition conversion --- src/zenml/integrations/aws/__init__.py | 4 + .../integrations/aws/flavors/__init__.py | 7 + .../aws/flavors/batch_step_operator_flavor.py | 157 +++++++++ .../aws/step_operators/__init__.py | 8 +- .../aws/step_operators/batch_step_operator.py | 304 ++++++++++++++++++ .../batch_step_operator_entrypoint_config.py | 39 +++ .../step_operators/sagemaker_step_operator.py | 4 +- 7 files changed, 518 insertions(+), 5 deletions(-) create mode 100644 src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py create mode 100644 src/zenml/integrations/aws/step_operators/batch_step_operator.py create mode 100644 src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index d6cf5866fa0..231d81a6bb0 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -28,6 +28,7 @@ AWS_CONTAINER_REGISTRY_FLAVOR = "aws" AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = "sagemaker" AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = "sagemaker" +AWS_BATCH_STEP_OPERATOR_FLAVOR = "aws_batch" # Service connector constants AWS_CONNECTOR_TYPE = "aws" @@ -42,6 +43,7 @@ class AWSIntegration(Integration): NAME = AWS REQUIREMENTS = [ "sagemaker>=2.237.3", + "boto3>=1.40.30", "kubernetes", "aws-profile-manager", ] @@ -64,6 +66,7 @@ def flavors(cls) -> List[Type[Flavor]]: AWSImageBuilderFlavor, SagemakerOrchestratorFlavor, SagemakerStepOperatorFlavor, + AWSBatchStepOperatorFlavor ) return [ @@ -71,4 +74,5 @@ def flavors(cls) -> List[Type[Flavor]]: AWSImageBuilderFlavor, SagemakerStepOperatorFlavor, SagemakerOrchestratorFlavor, + AWSBatchStepOperatorFlavor ] diff --git 
a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index a2cdc428add..d5c0c35ea64 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -29,6 +29,10 @@ SagemakerStepOperatorConfig, SagemakerStepOperatorFlavor, ) +from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( + AWSBatchStepOperatorConfig, + AWSBatchStepOperatorFlavor +) __all__ = [ "AWSContainerRegistryFlavor", @@ -39,4 +43,7 @@ "SagemakerStepOperatorConfig", "SagemakerOrchestratorFlavor", "SagemakerOrchestratorConfig", + "AWSBatchStepOperatorFlavor", + "AWSBatchStepOperatorConfig", + ] diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py new file mode 100644 index 00000000000..be94a8f29a4 --- /dev/null +++ b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py @@ -0,0 +1,157 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""AWS Batch step operator flavor.""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union + +from pydantic import Field + +from zenml.config.base_settings import BaseSettings +from zenml.integrations.aws import ( + AWS_RESOURCE_TYPE, + AWS_BATCH_STEP_OPERATOR_FLAVOR, +) +from zenml.models import ServiceConnectorRequirements +from zenml.step_operators.base_step_operator import ( + BaseStepOperatorConfig, + BaseStepOperatorFlavor, +) +from zenml.utils import deprecation_utils + +if TYPE_CHECKING: + from zenml.integrations.aws.step_operators import AWSBatchStepOperator + + +class AWSBatchStepOperatorSettings(BaseSettings): + """Settings for the AWS Batch step operator.""" + + instance_type: Optional[str] = Field( + None, + description="DEPRECATED: The instance type to use for the step execution. " + "Use estimator_args instead. Example: 'ml.m5.xlarge'", + ) + environment: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to pass to the container during execution. " + "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", + ) + + _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( + "instance_type" + ) + + +class AWSBatchStepOperatorConfig( + BaseStepOperatorConfig, AWSBatchStepOperatorSettings +): + """Config for the AWS Batch step operator.""" + + execution_role: str = Field( + "", + description="The ECS execution role required to execute the AWS Batch" \ + " jobs as ECS tasks." + ) + job_role: str = Field( + "", + description="The ECS job role required by the container runtime inside" \ + " the ECS task implementing the zenml step." + ) + + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. + + This designation is used to determine if the stack component can be + used with a local ZenML database or if it requires a remote ZenML + server. + + Returns: + True if this config is for a remote component, False otherwise. 
+ """ + return True + + +class AWSBatchStepOperatorFlavor(BaseStepOperatorFlavor): + """Flavor for the AWS Batch step operator.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return AWS_BATCH_STEP_OPERATOR_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/aws_batch.png" + + @property + def config_class(self) -> Type[AWSBatchStepOperatorConfig]: + """Returns BatchStepOperatorConfig config class. + + Returns: + The config class. + """ + return AWSBatchStepOperatorConfig + + @property + def implementation_class(self) -> Type["AWSBatchStepOperator"]: + """Implementation class. + + Returns: + The implementation class. 
+ """ + from zenml.integrations.aws.step_operators import AWSBatchStepOperator + + return AWSBatchStepOperator diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index 9eee3140d43..b2d32099141 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -13,8 +13,10 @@ # permissions and limitations under the License. """Initialization of the Sagemaker Step Operator.""" -from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa +from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401 SagemakerStepOperator, ) - -__all__ = ["SagemakerStepOperator"] +from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401 + AWSBatchStepOperator, +) +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/batch_step_operator.py new file mode 100644 index 00000000000..dd296ce3091 --- /dev/null +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator.py @@ -0,0 +1,304 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the AWS Batch Step Operator.""" + +import time +from typing import ( + TYPE_CHECKING, + Dict, + List, + Optional, + Tuple, + Type, + Union, + Literal, + cast, +) +from pydantic import BaseModel + +import boto3 +from sagemaker.estimator import Estimator +from sagemaker.inputs import TrainingInput +from sagemaker.session import Session + +from zenml.client import Client +from zenml.config.build_configuration import BuildConfiguration +from zenml.enums import StackComponentType +from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( + AWSBatchStepOperatorConfig, + AWSBatchStepOperatorSettings, +) +from zenml.integrations.aws.step_operators.batch_step_operator_entrypoint_config import ( + BATCH_STEP_ENV_VAR_SIZE_LIMIT, + AWSBatchEntrypointConfiguration, +) +from zenml.logger import get_logger +from zenml.stack import Stack, StackValidator +from zenml.step_operators import BaseStepOperator +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) +from zenml.utils.env_utils import split_environment_variables +from zenml.utils.string_utils import random_str + +if TYPE_CHECKING: + from zenml.config.base_settings import BaseSettings + from zenml.config.step_run_info import StepRunInfo + from zenml.models import PipelineDeploymentBase + +logger = get_logger(__name__) + +BATCH_DOCKER_IMAGE_KEY = "batch_step_operator" +_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" + +class AWSBatchJobDefinitionContainerProperties(BaseModel): + image: str + command: List[str] + jobRoleArn: str + executionRoleArn: str + environment: List[Dict[str,str]] = [] # keys: 'name','value' + instanceType: str + resourceRequirements: List[Dict[str,str]] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' + secrets: List[Dict[str,str]] = [] # keys: 'name','value' + +class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): + targetNodes: str + container: 
AWSBatchJobDefinitionContainerProperties + +class AWSBatchJobDefinitionNodeProperties(BaseModel): + # we include this class for completeness' sake to make it easier + # to add multinode support later + # for now, we'll set defaults to intuitively represent the only supported + # execution type ('container'); in reality AWS Batch will ignore this + # config + numNodes: int = 1 + mainNode: int = 0 + nodeRangeProperties: List[AWSBatchJobDefinitionNodePropertiesNodeRangeProperty] = [] + +class AWSBatchJobDefinitionRetryStrategy(BaseModel): + attempts: int = 2 + evaluateOnExit: List[Dict[str,str]] = [ + { + "onExitCode": "137", # out-of-memory killed + "action": "RETRY" + }, + { + "onReason": "*Host EC2*", + "action": "RETRY" + }, + { + "onExitCode": "*", # match everything else + "action": "EXIT" + } + ] + # Example: + # { + # 'onStatusReason': 'string', + # 'onReason': 'string', + # 'onExitCode': 'string', + # 'action': 'RETRY'|'EXIT' + # }, + +class AWSBatchJobDefinition(BaseModel): + jobDefinitionName: str + type: Literal['container','multinode'] = 'container' # we don't support multinode type in this version + parameters: Dict[str,str] = {} + schedulingPriority: int = 0 # ignored in FIFO queues + containerProperties: AWSBatchJobDefinitionContainerProperties + nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( + numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() + propagateTags: bool = False + timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' + tags: Dict[str,str] = {} + platformCapabilities: Literal['EC2','FARGATE'] = "EC2" #-- hardcode this to EC2, so we can use container and multinode interchangeably without worrying too much + + +class 
AWSBatchStepOperator(BaseStepOperator): + """Step operator to run a step on AWS Batch. + + This class defines code that builds an image with the ZenML entrypoint + to run as an AWS Batch job. + """ + + @property + def config(self) -> AWSBatchStepOperatorConfig: + """Returns the `AWSBatchStepOperatorConfig` config. + + Returns: + The configuration. + """ + return cast(AWSBatchStepOperatorConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the AWS Batch step operator. + + Returns: + The settings class. + """ + return AWSBatchStepOperatorSettings + + @property + def entrypoint_config_class( + self, + ) -> Type[StepOperatorEntrypointConfiguration]: + """Returns the entrypoint configuration class for this step operator. + + Returns: + The entrypoint configuration class for this step operator. + """ + return AWSBatchEntrypointConfiguration + + @property + def validator(self) -> Optional[StackValidator]: + """Validates the stack. + + Returns: + A validator that checks that the stack contains a remote container + registry and a remote artifact store. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.config.is_local: + return False, ( + "The Batch step operator runs code remotely and " + "needs to write files into the artifact store, but the " + f"artifact store `{stack.artifact_store.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote artifact store when using the Batch " + "step operator." + ) + + container_registry = stack.container_registry + assert container_registry is not None + + if container_registry.config.is_local: + return False, ( + "The Batch step operator runs code remotely and " + "needs to push/pull Docker images, but the " + f"container registry `{container_registry.name}` of the " + "active stack is local. 
Please ensure that your stack " + "contains a remote container registry when using the " + "Batch step operator." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.CONTAINER_REGISTRY, + StackComponentType.IMAGE_BUILDER, + }, + custom_validation_function=_validate_remote_components, + ) + + + def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: + """Utility to map zenml internal configurations to a valid AWS Batch + job definition.""" + pass + + + def get_docker_builds( + self, deployment: "PipelineDeploymentBase" + ) -> List["BuildConfiguration"]: + """Gets the Docker builds required for the component. + + Args: + deployment: The pipeline deployment for which to get the builds. + + Returns: + The required Docker builds. + """ + builds = [] + for step_name, step in deployment.step_configurations.items(): + if step.config.uses_step_operator(self.name): + build = BuildConfiguration( + key=BATCH_DOCKER_IMAGE_KEY, + settings=step.config.docker_settings, + step_name=step_name, + entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}", + ) + builds.append(build) + + return builds + + def launch( + self, + info: "StepRunInfo", + entrypoint_command: List[str], + environment: Dict[str, str], + ) -> None: + """Launches a step on AWS Batch. + + Args: + info: Information about the step run. + entrypoint_command: Command that executes the step. + environment: Environment variables to set in the step operator + environment. + + Raises: + RuntimeError: If the connector returns an object that is not a + `boto3.Session`. + """ + if not info.config.resource_settings.empty: + logger.warning( + "Specifying custom step resources is not supported for " + "the AWS Batch step operator. 
If you want to run this step " + "operator on specific resources, you can do so by configuring " + "a different instance type like this: " + "`zenml step-operator update %s " + "--instance_type=`", + self.name, + ) + + image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) + + settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + + batch = boto3.client('batch') + + # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. + step_name = Client().get_run_step(info.step_run_id).name + training_job_name = f"{info.pipeline.name}-{step_name}"[:55] + suffix = random_str(4) + unique_training_job_name = f"{training_job_name}-{suffix}" + + response = batch.register_job_definition( + jobDefinitionName=unique_training_job_name, + type='container', + containerProperties={ + 'image': image_name , + 'command': entrypoint_command, + } + ) + + job_definition = response['jobDefinitionName'] + + response = batch.submit_job( + jobName=unique_training_job_name, + jobQueue=self.config.job_queue_name, + jobDefinition=job_definition, + ) + + job_id = response['jobId'] + + while True: + response = batch.describe_jobs(jobs=[job_id]) + status = response['jobs'][0]['status'] + if status in ['SUCCEEDED', 'FAILED']: + break + time.sleep(10) + logger.info(f'Job completed with status {status}') \ No newline at end of file diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py new file mode 100644 index 00000000000..9cc99c6c4c8 --- /dev/null +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py @@ -0,0 +1,39 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Entrypoint configuration for ZenML Batch step operator.""" + +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) +from zenml.utils.env_utils import reconstruct_environment_variables + +BATCH_STEP_ENV_VAR_SIZE_LIMIT = 512 + + +class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration): + """Entrypoint configuration for ZenML Batch step operator. + + The only purpose of this entrypoint configuration is to reconstruct the + environment variables that exceed the maximum length of 512 characters + allowed for Batch steps from their individual components. 
+ """ + + def run(self) -> None: + """Runs the step.""" + # Reconstruct the environment variables that exceed the maximum length + # of 512 characters from their individual chunks + reconstruct_environment_variables() + + # Run the step + super().run() diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py index 6456d4d9a37..2df9a81d21e 100644 --- a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py @@ -121,12 +121,12 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: if container_registry.config.is_local: return False, ( - "The SageMaker step operator runs code remotely and " + "The Batch step operator runs code remotely and " "needs to push/pull Docker images, but the " f"container registry `{container_registry.name}` of the " "active stack is local. Please ensure that your stack " "contains a remote container registry when using the " - "SageMaker step operator." + "Batch step operator." 
) return True, "" From 6b076d12a75944f844d730dedbdd506c06512341 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 16:29:21 +0100 Subject: [PATCH 03/36] finished first draft of aws batch step operator --- .../aws/flavors/batch_step_operator_flavor.py | 34 ++-- .../aws/step_operators/batch_step_operator.py | 180 ++++++++++++++---- 2 files changed, 164 insertions(+), 50 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py index be94a8f29a4..ee60a46f3e6 100644 --- a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py @@ -37,35 +37,43 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" instance_type: Optional[str] = Field( - None, - description="DEPRECATED: The instance type to use for the step execution. " - "Use estimator_args instead. Example: 'ml.m5.xlarge'", + 'optimal', + description="The instance type for AWS Batch to use for the step" \ + " execution. Example: 'm5.xlarge'", ) environment: Dict[str, str] = Field( default_factory=dict, - description="Environment variables to pass to the container during execution. " - "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", + description="Environment variables to pass to the container during " \ + "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - - _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( - "instance_type" + timeout_seconds: int = Field( + default=120, + description="The number of seconds before AWS Batch times out the job." ) - class AWSBatchStepOperatorConfig( BaseStepOperatorConfig, AWSBatchStepOperatorSettings ): - """Config for the AWS Batch step operator.""" + """Config for the AWS Batch step operator. + + Note: We use ECS as a backend (not EKS), and EC2 as a compute engine (not + Fargate). 
This is because + - users can avoid the complexity of setting up an EKS cluster, and + - we can AWS Batch multinode type job support later, which requires EC2 + """ execution_role: str = Field( "", description="The ECS execution role required to execute the AWS Batch" \ - " jobs as ECS tasks." + " jobs as an ECS tasks." ) job_role: str = Field( "", - description="The ECS job role required by the container runtime insdide" \ - "the ECS task implementing the zenml step." + description="The ECS job role required by the container runtime inside" \ + "the ECS task." + ) + job_queue_name: str = Field( + description="The AWS Batch job queue to submit AWS Batch jobs to." ) @property diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/batch_step_operator.py index dd296ce3091..b8e93745055 100644 --- a/src/zenml/integrations/aws/step_operators/batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator.py @@ -54,6 +54,7 @@ if TYPE_CHECKING: from zenml.config.base_settings import BaseSettings + from zenml.config import ResourceSettings from zenml.config.step_run_info import StepRunInfo from zenml.models import PipelineDeploymentBase @@ -63,6 +64,7 @@ _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" class AWSBatchJobDefinitionContainerProperties(BaseModel): + """An AWS Batch job subconfiguration model for a container specification.""" image: str command: List[str] jobRoleArn: str @@ -73,20 +75,31 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel): secrets: List[Dict[str,str]] = [] # keys: 'name','value' class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): + """An AWS Batch job subconfiguration model for a node in a multinode job + specifications. + + Note: We include this class for completeness sake to make it easier to add + multinode support later for now. 
+ """ targetNodes: str container: AWSBatchJobDefinitionContainerProperties class AWSBatchJobDefinitionNodeProperties(BaseModel): - # we include this class for completeness sake to make it easier - # to add multinode support later - # for now, we'll set defaults to intuitively represent the only supported - # exeuction type ('container'); in reality AWS Batch will ignore this - # config + """An AWS Batch job subconfiguration model for multinode job specifications. + + Note: We include this class for completeness sake to make it easier to add + multinode support later for now, we'll set defaults to intuitively + represent the only supported exeuction type ('container'); in reality AWS + Batch will ignore this config. + """ numNodes: int = 1 mainNode: int = 0 - nodeRangeProperties: List[AWSBatchJobDefinitionNodePropertiesNodeRangeProperty] = [] + nodeRangeProperties: List[ + AWSBatchJobDefinitionNodePropertiesNodeRangeProperty + ] = [] class AWSBatchJobDefinitionRetryStrategy(BaseModel): + """An AWS Batch job subconfiguration model for retry specifications.""" attempts: int = 2 evaluateOnExit: List[Dict[str,str]] = [ { @@ -102,15 +115,15 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): "action": "EXIT" } ] - # Example: - # { - # 'onStatusReason': 'string', - # 'onReason': 'string', - # 'onExitCode': 'string', - # 'action': 'RETRY'|'EXIT' - # }, class AWSBatchJobDefinition(BaseModel): + """A utility to validate AWS Batch job descriptions. + + Defaults fall into two categories: + - reasonable default values + - aligning the job description to be a valid 'container' type configuration, + as multinode jobs are not supported yet.""" + jobDefinitionName: str type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version parameters: Dict[str,str] = {} @@ -126,15 +139,15 @@ class AWSBatchJobDefinition(BaseModel): class AWSBatchStepOperator(BaseStepOperator): - """Step operator to run a step on Sagemaker. 
+ """Step operator to run a step on AWS Batch. This class defines code that builds an image with the ZenML entrypoint - to run using Sagemaker's Estimator. + to run using AWS Batch. """ @property def config(self) -> AWSBatchStepOperatorConfig: - """Returns the `SagemakerStepOperatorConfig` config. + """Returns the `AWSBatchStepOperatorConfig` config. Returns: The configuration. @@ -143,7 +156,7 @@ def config(self) -> AWSBatchStepOperatorConfig: @property def settings_class(self) -> Optional[Type["BaseSettings"]]: - """Settings class for the SageMaker step operator. + """Settings class for the AWS Batch step operator. Returns: The settings class. @@ -204,11 +217,117 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: custom_validation_function=_validate_remote_components, ) + @staticmethod + def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: + """Utility to map the {name:value} environment to the + [{"name":name,"value":value},] convention used in the AWS Batch job + definition spec. + + Args: + environment (Dict[str,str]): The step's environment variable + specification + + Returns: + List[Dict[str,str]]: The mapped environment variable specification + """ + + return [ + {"name":k,"value":v} for k,v in environment + ] + + @staticmethod + def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,str]]: + """Utility to map the resource_settings to the resource convention used + in the AWS Batch Job definition spec. + + Args: + resource_settings (ResourceSettings): The step's resource settings. + + Returns: + List[Dict[str,str]]: The mapped resource settings. 
+ """ + mapped_resource_settings = [] + + if resource_settings.empty: + return mapped_resource_settings + else: + + if resource_settings.cpu_count is not None: + mapped_resource_settings.append( + { + "value": resource_settings.cpu_count, + "type": 'VCPU' + } + ) + + if resource_settings.gpu_count is not None: + mapped_resource_settings.append( + { + "value": resource_settings.gpu_count, + "type": 'GPU' + } + ) + + if resource_settings.get_memory() is not None: + mapped_resource_settings.append( + { + "value": resource_settings.get_memory(), + "type": 'MEMORY' + } + ) + + return mapped_resource_settings + + @staticmethod + def generate_unique_batch_job_name(info: "StepRunInfo") -> str: + """Utility to generate a unique AWS Batch job name. + + Args: + info (StepRunInfo): The step run information. + + Returns: + str: A unique name for the step's AWS Batch job definition + """ + + # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. + step_name = Client().get_run_step(info.step_run_id).name + job_name = f"{info.pipeline.name}-{step_name}"[:55] + suffix = random_str(4) + return f"{job_name}-{suffix}" def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: """Utility to map zenml internal configurations to a valid AWS Batch job definition.""" - pass + + image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) + + resource_settings = info.config.resource_settings + step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + + job_name = self.generate_unique_batch_job_name(info) + + return AWSBatchJobDefinition( + jobDefinitionName=job_name, + containerProperties=AWSBatchJobDefinitionContainerProperties( + executionRoleArn=self.config.execution_role, + jobRoleArn=self.config.job_role, + image=image_name, + command=entrypoint_command, + environment=self.map_environment(environment), + instanceType=step_settings.instance_type, + 
resourceRequirements=self.map_resource_settings(resource_settings), + ), + timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + # parameters: Dict[str,str] = {} + # schedulingPriority: int = 0 # ignored in FIFO queues + # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( + # numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + # retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() + # propagateTags: bool = False + # tags: Dict[str,str] = {} + # platformCapabilities: Literal['EC2','FARGATE'] = "EC2" + ) def get_docker_builds( @@ -264,33 +383,20 @@ def launch( self.name, ) - image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) - - settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + job_definition = self.generate_job_definition(info, entrypoint_command, environment) batch = boto3.client('batch') - - # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. 
- step_name = Client().get_run_step(info.step_run_id).name - training_job_name = f"{info.pipeline.name}-{step_name}"[:55] - suffix = random_str(4) - unique_training_job_name = f"{training_job_name}-{suffix}" - + response = batch.register_job_definition( - jobDefinitionName=unique_training_job_name, - type='container', - containerProperties={ - 'image': image_name , - 'command': entrypoint_command, - } + **job_definition.model_dump() ) - job_definition = response['jobDefinitionName'] + job_definition_name = response['jobDefinitionName'] response = batch.submit_job( - jobName=unique_training_job_name, + jobName=job_definition.jobDefinitionName, jobQueue=self.config.job_queue_name, - jobDefinition=job_definition, + jobDefinition=job_definition_name, ) job_id = response['jobId'] From c6f2a877126d23388ac6114176307f1f44a0b8e8 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 20:48:41 +0100 Subject: [PATCH 04/36] renaming modules and adding unit tests --- pyproject.toml | 2 +- .../integrations/aws/flavors/__init__.py | 2 +- ...r.py => aws_batch_step_operator_flavor.py} | 9 +- .../aws/step_operators/__init__.py | 2 +- ...operator.py => aws_batch_step_operator.py} | 30 +++---- ..._batch_step_operator_entrypoint_config.py} | 19 +--- .../aws/step_operators/__init__.py | 13 +++ .../test_aws_batch_step_operator.py | 86 +++++++++++++++++++ .../test_aws_batch_step_operator_flavor.py | 8 ++ 9 files changed, 130 insertions(+), 41 deletions(-) rename src/zenml/integrations/aws/flavors/{batch_step_operator_flavor.py => aws_batch_step_operator_flavor.py} (96%) rename src/zenml/integrations/aws/step_operators/{batch_step_operator.py => aws_batch_step_operator.py} (94%) rename src/zenml/integrations/aws/step_operators/{batch_step_operator_entrypoint_config.py => aws_batch_step_operator_entrypoint_config.py} (56%) create mode 100644 tests/integration/integrations/aws/step_operators/__init__.py create mode 100644 
tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py create mode 100644 tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py diff --git a/pyproject.toml b/pyproject.toml index 61e3bce7497..11213fe1442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,7 @@ dev = [ "darglint", "pytest-randomly", "pytest-mock", - "pytest-clarity", + #"pytest-clarity", "pytest-instafail", "pytest-rerunfailures", "pytest-split", diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index d5c0c35ea64..bc32651cc72 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -29,7 +29,7 @@ SagemakerStepOperatorConfig, SagemakerStepOperatorFlavor, ) -from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import ( AWSBatchStepOperatorConfig, AWSBatchStepOperatorFlavor ) diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py similarity index 96% rename from src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py rename to src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index ee60a46f3e6..f814da513a4 100644 --- a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, Optional, Type from pydantic import Field @@ -27,7 +27,6 @@ BaseStepOperatorConfig, BaseStepOperatorFlavor, ) -from zenml.utils import deprecation_utils if TYPE_CHECKING: from zenml.integrations.aws.step_operators import AWSBatchStepOperator @@ -36,13 +35,13 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: Optional[str] = Field( - 'optimal', + instance_type: str = Field( + default='optimal', description="The instance type for AWS Batch to use for the step" \ " execution. Example: 'm5.xlarge'", ) environment: Dict[str, str] = Field( - default_factory=dict, + default={}, description="Environment variables to pass to the container during " \ "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index b2d32099141..d04fccea431 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -16,7 +16,7 @@ from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401 SagemakerStepOperator, ) -from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401 +from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, ) __all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py similarity index 94% rename from src/zenml/integrations/aws/step_operators/batch_step_operator.py rename to src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index b8e93745055..4604b2a4064 100644 --- 
a/src/zenml/integrations/aws/step_operators/batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -14,6 +14,7 @@ """Implementation of the Sagemaker Step Operator.""" import time +import math from typing import ( TYPE_CHECKING, Dict, @@ -21,26 +22,20 @@ Optional, Tuple, Type, - Union, Literal, cast, ) from pydantic import BaseModel - import boto3 -from sagemaker.estimator import Estimator -from sagemaker.inputs import TrainingInput -from sagemaker.session import Session from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration from zenml.enums import StackComponentType -from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import ( AWSBatchStepOperatorConfig, AWSBatchStepOperatorSettings, ) -from zenml.integrations.aws.step_operators.batch_step_operator_entrypoint_config import ( - BATCH_STEP_ENV_VAR_SIZE_LIMIT, +from zenml.integrations.aws.step_operators.aws_batch_step_operator_entrypoint_config import ( AWSBatchEntrypointConfiguration, ) from zenml.logger import get_logger @@ -49,7 +44,6 @@ from zenml.step_operators.step_operator_entrypoint_configuration import ( StepOperatorEntrypointConfiguration, ) -from zenml.utils.env_utils import split_environment_variables from zenml.utils.string_utils import random_str if TYPE_CHECKING: @@ -60,7 +54,7 @@ logger = get_logger(__name__) -BATCH_DOCKER_IMAGE_KEY = "batch_step_operator" +BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" class AWSBatchJobDefinitionContainerProperties(BaseModel): @@ -232,11 +226,11 @@ def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: """ return [ - {"name":k,"value":v} for k,v in environment + {"name":k,"value":v} for k,v in environment.items() ] @staticmethod - def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,str]]: + def 
map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[str,str]]: """Utility to map the resource_settings to the resource convention used in the AWS Batch Job definition spec. @@ -253,9 +247,15 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, else: if resource_settings.cpu_count is not None: + + cpu_count_int = math.ceil(resource_settings.cpu_count) + + if cpu_count_int != resource_settings.cpu_count: + logger.info(f"AWS Batch only accepts int type cpu resource requirements. Converted {resource_settings.cpu_count} to {cpu_count_int}") + mapped_resource_settings.append( { - "value": resource_settings.cpu_count, + "value": str(cpu_count_int), "type": 'VCPU' } ) @@ -263,7 +263,7 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, if resource_settings.gpu_count is not None: mapped_resource_settings.append( { - "value": resource_settings.gpu_count, + "value": str(resource_settings.gpu_count), "type": 'GPU' } ) @@ -271,7 +271,7 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, if resource_settings.get_memory() is not None: mapped_resource_settings.append( { - "value": resource_settings.get_memory(), + "value": str(int(resource_settings.get_memory(unit="MiB"))), "type": 'MEMORY' } ) diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py similarity index 56% rename from src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py rename to src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py index 9cc99c6c4c8..cd57398a38f 100644 --- a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py @@ -16,24 +16,7 @@ from 
zenml.step_operators.step_operator_entrypoint_configuration import ( StepOperatorEntrypointConfiguration, ) -from zenml.utils.env_utils import reconstruct_environment_variables - -BATCH_STEP_ENV_VAR_SIZE_LIMIT = 512 class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration): - """Entrypoint configuration for ZenML Batch step operator. - - The only purpose of this entrypoint configuration is to reconstruct the - environment variables that exceed the maximum length of 512 characters - allowed for Batch steps from their individual components. - """ - - def run(self) -> None: - """Runs the step.""" - # Reconstruct the environment variables that exceed the maximum length - # of 512 characters from their individual chunks - reconstruct_environment_variables() - - # Run the step - super().run() + """Entrypoint configuration for ZenML Batch step operator.""" \ No newline at end of file diff --git a/tests/integration/integrations/aws/step_operators/__init__.py b/tests/integration/integrations/aws/step_operators/__init__.py new file mode 100644 index 00000000000..cd90a82cfc2 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py new file mode 100644 index 00000000000..6ab344da4f1 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -0,0 +1,86 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import pytest + +from zenml.config.resource_settings import ResourceSettings +from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator + +def test_aws_batch_step_operator_map_environment(): + test_environment = {'key_1':'value_1','key_2':'value_2'} + expected = [ + { + "name": "key_1", + "value": "value_1" + }, + { + "name": "key_2", + "value": "value_2" + } + ] + + assert AWSBatchStepOperator.map_environment(test_environment) == expected + +@pytest.mark.parametrize( + "test_resource_settings,expected", + [ + ( + ResourceSettings(), + [] + ), + ( + ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"), + [ + { + "value":"1", + "type":"VCPU" + }, + { + "value":"1", + "type":"GPU" + }, + { + "value": "10", + "type": "MEMORY" + } + ] + ), + ( + ResourceSettings(cpu_count=1,gpu_count=1), + [ + { + "value":"1", + "type":"VCPU" + }, + { + "value":"1", + "type":"GPU" + }, + ] + ), + ( + ResourceSettings(memory="1GiB"), + [ + { + "value": "1024", + "type": "MEMORY" + } + ] + ), + + ] 
+) +def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected): + assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py new file mode 100644 index 00000000000..2a160dd8127 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py @@ -0,0 +1,8 @@ +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings + +def test_aws_batch_step_operator_settings(): + AWSBatchStepOperatorSettings( + instance_type="g4dn.xlarge", + environment={"key_1":"value_1","key_2":"value_2"}, + timeout_seconds=60 + ) \ No newline at end of file From 01017cc3c68f5a89438b9c4e0d0b9eaea0ef5619 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 00:18:38 +0100 Subject: [PATCH 05/36] added support for multinode aws batch job type --- .../flavors/aws_batch_step_operator_flavor.py | 16 ++++-- .../step_operators/aws_batch_step_operator.py | 51 ++++++++++++------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index f814da513a4..5eb3520f746 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -13,9 +13,9 @@ # permissions and limitations under the License. 
"""Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union -from pydantic import Field +from pydantic import Field, PositiveInt from zenml.config.base_settings import BaseSettings from zenml.integrations.aws import ( @@ -35,7 +35,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: str = Field( + instance_type: Union[str,List[str]] = Field( default='optimal', description="The instance type for AWS Batch to use for the step" \ " execution. Example: 'm5.xlarge'", @@ -45,11 +45,19 @@ class AWSBatchStepOperatorSettings(BaseSettings): description="Environment variables to pass to the container during " \ "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - timeout_seconds: int = Field( + node_count: PositiveInt = Field( + description="The number of AWS Batch nodes to run the step on. If > 1," \ + "an AWS Batch multinode job will be run, with the network connectivity" \ + "between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \ + "for details." + ) + timeout_seconds: PositiveInt = Field( default=120, description="The number of seconds before AWS Batch times out the job." 
) + + class AWSBatchStepOperatorConfig( BaseStepOperatorConfig, AWSBatchStepOperatorSettings ): diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 4604b2a4064..a90863e7c60 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,7 +25,7 @@ Literal, cast, ) -from pydantic import BaseModel +from pydantic import BaseModel, PositiveInt import boto3 from zenml.client import Client @@ -71,22 +71,14 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel): class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): """An AWS Batch job subconfiguration model for a node in a multinode job specifications. - - Note: We include this class for completeness sake to make it easier to add - multinode support later for now. """ targetNodes: str container: AWSBatchJobDefinitionContainerProperties class AWSBatchJobDefinitionNodeProperties(BaseModel): """An AWS Batch job subconfiguration model for multinode job specifications. - - Note: We include this class for completeness sake to make it easier to add - multinode support later for now, we'll set defaults to intuitively - represent the only supported exeuction type ('container'); in reality AWS - Batch will ignore this config. 
""" - numNodes: int = 1 + numNodes: PositiveInt = 1 mainNode: int = 0 nodeRangeProperties: List[ AWSBatchJobDefinitionNodePropertiesNodeRangeProperty @@ -94,7 +86,7 @@ class AWSBatchJobDefinitionNodeProperties(BaseModel): class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" - attempts: int = 2 + attempts: PositiveInt = 2 evaluateOnExit: List[Dict[str,str]] = [ { "onExitCode": "137", # out-of-memory killed @@ -122,9 +114,8 @@ class AWSBatchJobDefinition(BaseModel): type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version parameters: Dict[str,str] = {} schedulingPriority: int = 0 # ignored in FIFO queues - containerProperties: AWSBatchJobDefinitionContainerProperties - nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( - numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None + nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() propagateTags: bool = False timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' @@ -305,10 +296,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) job_name = self.generate_unique_batch_job_name(info) - - return AWSBatchJobDefinition( - jobDefinitionName=job_name, - containerProperties=AWSBatchJobDefinitionContainerProperties( + container_properties = AWSBatchJobDefinitionContainerProperties( executionRoleArn=self.config.execution_role, jobRoleArn=self.config.job_role, image=image_name, @@ -317,8 +305,33 @@ def 
generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ instanceType=step_settings.instance_type, resourceRequirements=self.map_resource_settings(resource_settings), ), + + node_count = step_settings.node_count + + if node_count == 1: + kwargs = { + 'type':'container', + 'containerProperties':container_properties + } + else: + kwargs = { + 'type':'multinode', + 'nodeProperties':AWSBatchJobDefinitionNodeProperties( + numNodes=node_count, + nodeRangeProperties=[ + AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( + targetNodes=','.join([str(node_index) for node_index in range(node_count)]), + container=container_properties + ) + ] + ) + } + + + return AWSBatchJobDefinition( + jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, - # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + **kwargs # parameters: Dict[str,str] = {} # schedulingPriority: int = 0 # ignored in FIFO queues # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( From b22672bcf8c5523991db855c4f49c47ef29023e7 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 00:59:04 +0100 Subject: [PATCH 06/36] added support for multinode aws batch job type --- .../flavors/aws_batch_step_operator_flavor.py | 1 + .../aws/step_operators/__init__.py | 3 +- .../step_operators/aws_batch_step_operator.py | 40 ++++++++++++++++++- .../test_aws_batch_step_operator.py | 20 +++++++++- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 5eb3520f746..80e549d61bd 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -46,6 +46,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): "execution. 
Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) node_count: PositiveInt = Field( + default=1, description="The number of AWS Batch nodes to run the step on. If > 1," \ "an AWS Batch multinode job will be run, with the network connectivity" \ "between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \ diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index d04fccea431..e680dc1d063 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -18,5 +18,6 @@ ) from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, + get_aws_batch_context ) -__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_aws_batch_context"] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index a90863e7c60..a2e901f66c4 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,7 +25,8 @@ Literal, cast, ) -from pydantic import BaseModel, PositiveInt +from pydantic import BaseModel, PositiveInt, Field +from pydantic_settings import BaseSettings as PydanticBaseSettings import boto3 from zenml.client import Client @@ -57,6 +58,43 @@ BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" +class AWSBatchContext(PydanticBaseSettings): + """A utility to access the AWS Batch job context environment variables.""" + main_node_index: int = Field( + description="This variable is set to the index number of the job's " + "main node. 
Your application code can compare the " + "AWS_BATCH_JOB_MAIN_NODE_INDEX to the AWS_BATCH_JOB_NODE_INDEX on " + "an individual node to determine if it's the main node.", + alias="AWS_BATCH_JOB_MAIN_NODE_INDEX" + ) + + main_node_address: str = Field( + description="This variable is only set in multi-node parallel job " + "child nodes. This variable isn't present on the main node. This " + "variable is set to the private IPv4 address of the job's main " + "node. Your child node's application code can use this address to " + "communicate with the main node.", + alias="AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS" + ) + + node_index: int = Field( + description="This variable is set to the node index number of the " + "node. The node index begins at 0, and each node receives a unique" + " index number. For example, a multi-node parallel job with 10 " + "children has index values of 0-9.", + alias="AWS_BATCH_JOB_NODE_INDEX" + ) + + num_nodes: int = Field( + description="This variable is set to the number of nodes that you have" + "requested for your multi-node parallel job.", + alias="AWS_BATCH_JOB_NUM_NODES" + ) + +def get_aws_batch_context() -> AWSBatchContext: + """Utility to retrieve the AWS Batch runtime context.""" + return AWSBatchContext() + class AWSBatchJobDefinitionContainerProperties(BaseModel): """An AWS Batch job subconfiguration model for a container specification.""" image: str diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index 6ab344da4f1..b032818cfbf 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -16,9 +16,25 @@ import pytest from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator +from 
zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_aws_batch_context + +def test_aws_batch_context(monkeypatch): + """Tests the AWSBatchContext class.""" + + monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_INDEX',0) + monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS','test-address') + monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1) + monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2) + + test_aws_batch_context = get_aws_batch_context() + assert test_aws_batch_context.main_node_index == 0 + assert test_aws_batch_context.main_node_address == 'test-address' + assert test_aws_batch_context.node_index == 1 + assert test_aws_batch_context.num_nodes == 2 def test_aws_batch_step_operator_map_environment(): + """Tests the AWSBatchStepOperator's map_environment class method.""" + test_environment = {'key_1':'value_1','key_2':'value_2'} expected = [ { @@ -83,4 +99,6 @@ def test_aws_batch_step_operator_map_environment(): ] ) def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected): + """Tests the AWSBatchStepOperator's map_resource_settings class method.""" + assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected From 371a4ac98cfdee5b65c27a7e3e59a917990354ce Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 19:35:38 +0100 Subject: [PATCH 07/36] adding test dependency back in and fixing typo in sagemaker doc string --- pyproject.toml | 2 +- .../aws/step_operators/sagemaker_step_operator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 11213fe1442..61e3bce7497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,7 @@ dev = [ "darglint", "pytest-randomly", "pytest-mock", - #"pytest-clarity", + "pytest-clarity", "pytest-instafail", "pytest-rerunfailures", "pytest-split", diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py 
b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py index 2df9a81d21e..6456d4d9a37 100644 --- a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py @@ -121,12 +121,12 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: if container_registry.config.is_local: return False, ( - "The Batch step operator runs code remotely and " + "The SageMaker step operator runs code remotely and " "needs to push/pull Docker images, but the " f"container registry `{container_registry.name}` of the " "active stack is local. Please ensure that your stack " "contains a remote container registry when using the " - "Batch step operator." + "SageMaker step operator." ) return True, "" From a80f2665d2030f09800f3202238e2f7a50adccd6 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 19:39:20 +0100 Subject: [PATCH 08/36] renaming the aws batch runtime context retrieval utility --- src/zenml/integrations/aws/step_operators/__init__.py | 4 ++-- .../aws/step_operators/aws_batch_step_operator.py | 2 +- .../aws/step_operators/test_aws_batch_step_operator.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index e680dc1d063..ecc24b98f0e 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -18,6 +18,6 @@ ) from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, - get_aws_batch_context + get_context ) -__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_aws_batch_context"] +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_context"] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py 
b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index a2e901f66c4..69ef0b49efe 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -91,7 +91,7 @@ class AWSBatchContext(PydanticBaseSettings): alias="AWS_BATCH_JOB_NUM_NODES" ) -def get_aws_batch_context() -> AWSBatchContext: +def get_context() -> AWSBatchContext: """Utility to retrieve the AWS Batch runtime context.""" return AWSBatchContext() diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index b032818cfbf..3aaa7bf9a43 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -16,7 +16,7 @@ import pytest from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_aws_batch_context +from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_context def test_aws_batch_context(monkeypatch): """Tests the AWSBatchContext class.""" @@ -26,7 +26,7 @@ def test_aws_batch_context(monkeypatch): monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1) monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2) - test_aws_batch_context = get_aws_batch_context() + test_aws_batch_context = get_context() assert test_aws_batch_context.main_node_index == 0 assert test_aws_batch_context.main_node_address == 'test-address' assert test_aws_batch_context.node_index == 1 From e372f85c4b245d655d69f3c21d480ea496c27232 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 00:40:49 +0100 Subject: [PATCH 09/36] started creating required files and mapping out the zenml config -> aws batch job definition conversion --- 
src/zenml/integrations/aws/__init__.py | 4 + .../integrations/aws/flavors/__init__.py | 7 + .../aws/flavors/batch_step_operator_flavor.py | 157 +++++++++ .../aws/step_operators/__init__.py | 8 +- .../aws/step_operators/batch_step_operator.py | 304 ++++++++++++++++++ .../batch_step_operator_entrypoint_config.py | 39 +++ .../step_operators/sagemaker_step_operator.py | 4 +- 7 files changed, 518 insertions(+), 5 deletions(-) create mode 100644 src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py create mode 100644 src/zenml/integrations/aws/step_operators/batch_step_operator.py create mode 100644 src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index d6cf5866fa0..231d81a6bb0 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -28,6 +28,7 @@ AWS_CONTAINER_REGISTRY_FLAVOR = "aws" AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = "sagemaker" AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = "sagemaker" +AWS_BATCH_STEP_OPERATOR_FLAVOR = "aws_batch" # Service connector constants AWS_CONNECTOR_TYPE = "aws" @@ -42,6 +43,7 @@ class AWSIntegration(Integration): NAME = AWS REQUIREMENTS = [ "sagemaker>=2.237.3", + "boto3>=1.40.30", "kubernetes", "aws-profile-manager", ] @@ -64,6 +66,7 @@ def flavors(cls) -> List[Type[Flavor]]: AWSImageBuilderFlavor, SagemakerOrchestratorFlavor, SagemakerStepOperatorFlavor, + AWSBatchStepOperatorFlavor ) return [ @@ -71,4 +74,5 @@ def flavors(cls) -> List[Type[Flavor]]: AWSImageBuilderFlavor, SagemakerStepOperatorFlavor, SagemakerOrchestratorFlavor, + AWSBatchStepOperatorFlavor ] diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index a2cdc428add..d5c0c35ea64 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -29,6 +29,10 @@ SagemakerStepOperatorConfig, 
SagemakerStepOperatorFlavor, ) +from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( + AWSBatchStepOperatorConfig, + AWSBatchStepOperatorFlavor +) __all__ = [ "AWSContainerRegistryFlavor", @@ -39,4 +43,7 @@ "SagemakerStepOperatorConfig", "SagemakerOrchestratorFlavor", "SagemakerOrchestratorConfig", + "AWSBatchStepOperatorFlavor", + "AWSBatchStepOperatorConfig", + ] diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py new file mode 100644 index 00000000000..be94a8f29a4 --- /dev/null +++ b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py @@ -0,0 +1,157 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Amazon SageMaker step operator flavor.""" + +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union + +from pydantic import Field + +from zenml.config.base_settings import BaseSettings +from zenml.integrations.aws import ( + AWS_RESOURCE_TYPE, + AWS_BATCH_STEP_OPERATOR_FLAVOR, +) +from zenml.models import ServiceConnectorRequirements +from zenml.step_operators.base_step_operator import ( + BaseStepOperatorConfig, + BaseStepOperatorFlavor, +) +from zenml.utils import deprecation_utils + +if TYPE_CHECKING: + from zenml.integrations.aws.step_operators import AWSBatchStepOperator + + +class AWSBatchStepOperatorSettings(BaseSettings): + """Settings for the Sagemaker step operator.""" + + instance_type: Optional[str] = Field( + None, + description="DEPRECATED: The instance type to use for the step execution. " + "Use estimator_args instead. Example: 'ml.m5.xlarge'", + ) + environment: Dict[str, str] = Field( + default_factory=dict, + description="Environment variables to pass to the container during execution. " + "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", + ) + + _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( + "instance_type" + ) + + +class AWSBatchStepOperatorConfig( + BaseStepOperatorConfig, AWSBatchStepOperatorSettings +): + """Config for the AWS Batch step operator.""" + + execution_role: str = Field( + "", + description="The ECS execution role required to execute the AWS Batch" \ + " jobs as ECS tasks." + ) + job_role: str = Field( + "", + description="The ECS job role required by the container runtime insdide" \ + "the ECS task implementing the zenml step." + ) + + @property + def is_remote(self) -> bool: + """Checks if this stack component is running remotely. + + This designation is used to determine if the stack component can be + used with a local ZenML database or if it requires a remote ZenML + server. + + Returns: + True if this config is for a remote component, False otherwise. 
+ """ + return True + + +class AWSBatchStepOperatorFlavor(BaseStepOperatorFlavor): + """Flavor for the AWS Batch step operator.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The name of the flavor. + """ + return AWS_BATCH_STEP_OPERATOR_FLAVOR + + @property + def service_connector_requirements( + self, + ) -> Optional[ServiceConnectorRequirements]: + """Service connector resource requirements for service connectors. + + Specifies resource requirements that are used to filter the available + service connector types that are compatible with this flavor. + + Returns: + Requirements for compatible service connectors, if a service + connector is required for this flavor. + """ + return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE) + + @property + def docs_url(self) -> Optional[str]: + """A url to point at docs explaining this flavor. + + Returns: + A flavor docs url. + """ + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """A url to point at SDK docs explaining this flavor. + + Returns: + A flavor SDK docs url. + """ + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """A url to represent the flavor in the dashboard. + + Returns: + The flavor logo. + """ + return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/aws_batch.png" + + @property + def config_class(self) -> Type[AWSBatchStepOperatorConfig]: + """Returns BatchStepOperatorConfig config class. + + Returns: + The config class. + """ + return AWSBatchStepOperatorConfig + + @property + def implementation_class(self) -> Type["AWSBatchStepOperator"]: + """Implementation class. + + Returns: + The implementation class. 
+ """ + from zenml.integrations.aws.step_operators import AWSBatchStepOperator + + return AWSBatchStepOperator diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index 9eee3140d43..b2d32099141 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -13,8 +13,10 @@ # permissions and limitations under the License. """Initialization of the Sagemaker Step Operator.""" -from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa +from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401 SagemakerStepOperator, ) - -__all__ = ["SagemakerStepOperator"] +from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401 + AWSBatchStepOperator, +) +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/batch_step_operator.py new file mode 100644 index 00000000000..dd296ce3091 --- /dev/null +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator.py @@ -0,0 +1,304 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Sagemaker Step Operator.""" + +import time +from typing import ( + TYPE_CHECKING, + Dict, + List, + Optional, + Tuple, + Type, + Union, + Literal, + cast, +) +from pydantic import BaseModel + +import boto3 +from sagemaker.estimator import Estimator +from sagemaker.inputs import TrainingInput +from sagemaker.session import Session + +from zenml.client import Client +from zenml.config.build_configuration import BuildConfiguration +from zenml.enums import StackComponentType +from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( + AWSBatchStepOperatorConfig, + AWSBatchStepOperatorSettings, +) +from zenml.integrations.aws.step_operators.batch_step_operator_entrypoint_config import ( + BATCH_STEP_ENV_VAR_SIZE_LIMIT, + AWSBatchEntrypointConfiguration, +) +from zenml.logger import get_logger +from zenml.stack import Stack, StackValidator +from zenml.step_operators import BaseStepOperator +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) +from zenml.utils.env_utils import split_environment_variables +from zenml.utils.string_utils import random_str + +if TYPE_CHECKING: + from zenml.config.base_settings import BaseSettings + from zenml.config.step_run_info import StepRunInfo + from zenml.models import PipelineDeploymentBase + +logger = get_logger(__name__) + +BATCH_DOCKER_IMAGE_KEY = "batch_step_operator" +_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" + +class AWSBatchJobDefinitionContainerProperties(BaseModel): + image: str + command: List[str] + jobRoleArn: str + executionRoleArn: str + environment: List[Dict[str,str]] = [] # keys: 'name','value' + instanceType: str + resourceRequirements: List[Dict[str,str]] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' + secrets: List[Dict[str,str]] = [] # keys: 'name','value' + +class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): + targetNodes: str + container: 
AWSBatchJobDefinitionContainerProperties + +class AWSBatchJobDefinitionNodeProperties(BaseModel): + # we include this class for completeness sake to make it easier + # to add multinode support later + # for now, we'll set defaults to intuitively represent the only supported + # exeuction type ('container'); in reality AWS Batch will ignore this + # config + numNodes: int = 1 + mainNode: int = 0 + nodeRangeProperties: List[AWSBatchJobDefinitionNodePropertiesNodeRangeProperty] = [] + +class AWSBatchJobDefinitionRetryStrategy(BaseModel): + attempts: int = 2 + evaluateOnExit: List[Dict[str,str]] = [ + { + "onExitCode": "137", # out-of-memory killed + "action": "RETRY" + }, + { + "onReason": "*Host EC2*", + "action": "RETRY" + }, + { + "onExitCode": "*", # match everything else + "action": "EXIT" + } + ] + # Example: + # { + # 'onStatusReason': 'string', + # 'onReason': 'string', + # 'onExitCode': 'string', + # 'action': 'RETRY'|'EXIT' + # }, + +class AWSBatchJobDefinition(BaseModel): + jobDefinitionName: str + type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + parameters: Dict[str,str] = {} + schedulingPriority: int = 0 # ignored in FIFO queues + containerProperties: AWSBatchJobDefinitionContainerProperties + nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( + numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() + propagateTags: bool = False + timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' + tags: Dict[str,str] = {} + platformCapabilities: Literal['EC2','FARGATE'] = "EC2" #-- hardcode this to EC2, so we can use container and multinode interchangeably without worrying too much + + +class 
AWSBatchStepOperator(BaseStepOperator): + """Step operator to run a step on Sagemaker. + + This class defines code that builds an image with the ZenML entrypoint + to run using Sagemaker's Estimator. + """ + + @property + def config(self) -> AWSBatchStepOperatorConfig: + """Returns the `SagemakerStepOperatorConfig` config. + + Returns: + The configuration. + """ + return cast(AWSBatchStepOperatorConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Settings class for the SageMaker step operator. + + Returns: + The settings class. + """ + return AWSBatchStepOperatorSettings + + @property + def entrypoint_config_class( + self, + ) -> Type[StepOperatorEntrypointConfiguration]: + """Returns the entrypoint configuration class for this step operator. + + Returns: + The entrypoint configuration class for this step operator. + """ + return AWSBatchEntrypointConfiguration + + @property + def validator(self) -> Optional[StackValidator]: + """Validates the stack. + + Returns: + A validator that checks that the stack contains a remote container + registry and a remote artifact store. + """ + + def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + if stack.artifact_store.config.is_local: + return False, ( + "The Batch step operator runs code remotely and " + "needs to write files into the artifact store, but the " + f"artifact store `{stack.artifact_store.name}` of the " + "active stack is local. Please ensure that your stack " + "contains a remote artifact store when using the Batch " + "step operator." + ) + + container_registry = stack.container_registry + assert container_registry is not None + + if container_registry.config.is_local: + return False, ( + "The Batch step operator runs code remotely and " + "needs to push/pull Docker images, but the " + f"container registry `{container_registry.name}` of the " + "active stack is local. 
Please ensure that your stack " + "contains a remote container registry when using the " + "Batch step operator." + ) + + return True, "" + + return StackValidator( + required_components={ + StackComponentType.CONTAINER_REGISTRY, + StackComponentType.IMAGE_BUILDER, + }, + custom_validation_function=_validate_remote_components, + ) + + + def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: + """Utility to map zenml internal configurations to a valid AWS Batch + job definition.""" + pass + + + def get_docker_builds( + self, deployment: "PipelineDeploymentBase" + ) -> List["BuildConfiguration"]: + """Gets the Docker builds required for the component. + + Args: + deployment: The pipeline deployment for which to get the builds. + + Returns: + The required Docker builds. + """ + builds = [] + for step_name, step in deployment.step_configurations.items(): + if step.config.uses_step_operator(self.name): + build = BuildConfiguration( + key=BATCH_DOCKER_IMAGE_KEY, + settings=step.config.docker_settings, + step_name=step_name, + entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}", + ) + builds.append(build) + + return builds + + def launch( + self, + info: "StepRunInfo", + entrypoint_command: List[str], + environment: Dict[str, str], + ) -> None: + """Launches a step on AWS Batch. + + Args: + info: Information about the step run. + entrypoint_command: Command that executes the step. + environment: Environment variables to set in the step operator + environment. + + Raises: + RuntimeError: If the connector returns an object that is not a + `boto3.Session`. + """ + if not info.config.resource_settings.empty: + logger.warning( + "Specifying custom step resources is not supported for " + "the AWS Batch step operator. 
If you want to run this step " + "operator on specific resources, you can do so by configuring " + "a different instance type like this: " + "`zenml step-operator update %s " + "--instance_type=`", + self.name, + ) + + image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) + + settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + + batch = boto3.client('batch') + + # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. + step_name = Client().get_run_step(info.step_run_id).name + training_job_name = f"{info.pipeline.name}-{step_name}"[:55] + suffix = random_str(4) + unique_training_job_name = f"{training_job_name}-{suffix}" + + response = batch.register_job_definition( + jobDefinitionName=unique_training_job_name, + type='container', + containerProperties={ + 'image': image_name , + 'command': entrypoint_command, + } + ) + + job_definition = response['jobDefinitionName'] + + response = batch.submit_job( + jobName=unique_training_job_name, + jobQueue=self.config.job_queue_name, + jobDefinition=job_definition, + ) + + job_id = response['jobId'] + + while True: + response = batch.describe_jobs(jobs=[job_id]) + status = response['jobs'][0]['status'] + if status in ['SUCCEEDED', 'FAILED']: + break + time.sleep(10) + logger.info(f'Job completed with status {status}') \ No newline at end of file diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py new file mode 100644 index 00000000000..9cc99c6c4c8 --- /dev/null +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py @@ -0,0 +1,39 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Entrypoint configuration for ZenML Batch step operator.""" + +from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, +) +from zenml.utils.env_utils import reconstruct_environment_variables + +BATCH_STEP_ENV_VAR_SIZE_LIMIT = 512 + + +class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration): + """Entrypoint configuration for ZenML Batch step operator. + + The only purpose of this entrypoint configuration is to reconstruct the + environment variables that exceed the maximum length of 512 characters + allowed for Batch steps from their individual components. 
+ """ + + def run(self) -> None: + """Runs the step.""" + # Reconstruct the environment variables that exceed the maximum length + # of 512 characters from their individual chunks + reconstruct_environment_variables() + + # Run the step + super().run() diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py index 6456d4d9a37..2df9a81d21e 100644 --- a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py @@ -121,12 +121,12 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: if container_registry.config.is_local: return False, ( - "The SageMaker step operator runs code remotely and " + "The Batch step operator runs code remotely and " "needs to push/pull Docker images, but the " f"container registry `{container_registry.name}` of the " "active stack is local. Please ensure that your stack " "contains a remote container registry when using the " - "SageMaker step operator." + "Batch step operator." 
) return True, "" From 3d8c39b94d015dfbc4135eaa6e949e821025b3e6 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 16:29:21 +0100 Subject: [PATCH 10/36] finished first draft of aws batch step operator --- .../aws/flavors/batch_step_operator_flavor.py | 34 ++-- .../aws/step_operators/batch_step_operator.py | 180 ++++++++++++++---- 2 files changed, 164 insertions(+), 50 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py index be94a8f29a4..ee60a46f3e6 100644 --- a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py @@ -37,35 +37,43 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" instance_type: Optional[str] = Field( - None, - description="DEPRECATED: The instance type to use for the step execution. " - "Use estimator_args instead. Example: 'ml.m5.xlarge'", + 'optimal', + description="The instance type for AWS Batch to use for the step" \ + " execution. Example: 'm5.xlarge'", ) environment: Dict[str, str] = Field( default_factory=dict, - description="Environment variables to pass to the container during execution. " - "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", + description="Environment variables to pass to the container during " \ + "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - - _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( - "instance_type" + timeout_seconds: int = Field( + default=120, + description="The number of seconds before AWS Batch times out the job." ) - class AWSBatchStepOperatorConfig( BaseStepOperatorConfig, AWSBatchStepOperatorSettings ): - """Config for the AWS Batch step operator.""" + """Config for the AWS Batch step operator. + + Note: We use ECS as a backend (not EKS), and EC2 as a compute engine (not + Fargate). 
This is because + - users can avoid the complexity of setting up an EKS cluster, and + - we can AWS Batch multinode type job support later, which requires EC2 + """ execution_role: str = Field( "", description="The ECS execution role required to execute the AWS Batch" \ - " jobs as ECS tasks." + " jobs as an ECS tasks." ) job_role: str = Field( "", - description="The ECS job role required by the container runtime insdide" \ - "the ECS task implementing the zenml step." + description="The ECS job role required by the container runtime inside" \ + "the ECS task." + ) + job_queue_name: str = Field( + description="The AWS Batch job queue to submit AWS Batch jobs to." ) @property diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/batch_step_operator.py index dd296ce3091..b8e93745055 100644 --- a/src/zenml/integrations/aws/step_operators/batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/batch_step_operator.py @@ -54,6 +54,7 @@ if TYPE_CHECKING: from zenml.config.base_settings import BaseSettings + from zenml.config import ResourceSettings from zenml.config.step_run_info import StepRunInfo from zenml.models import PipelineDeploymentBase @@ -63,6 +64,7 @@ _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" class AWSBatchJobDefinitionContainerProperties(BaseModel): + """An AWS Batch job subconfiguration model for a container specification.""" image: str command: List[str] jobRoleArn: str @@ -73,20 +75,31 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel): secrets: List[Dict[str,str]] = [] # keys: 'name','value' class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): + """An AWS Batch job subconfiguration model for a node in a multinode job + specifications. + + Note: We include this class for completeness sake to make it easier to add + multinode support later for now. 
+ """ targetNodes: str container: AWSBatchJobDefinitionContainerProperties class AWSBatchJobDefinitionNodeProperties(BaseModel): - # we include this class for completeness sake to make it easier - # to add multinode support later - # for now, we'll set defaults to intuitively represent the only supported - # exeuction type ('container'); in reality AWS Batch will ignore this - # config + """An AWS Batch job subconfiguration model for multinode job specifications. + + Note: We include this class for completeness sake to make it easier to add + multinode support later for now, we'll set defaults to intuitively + represent the only supported exeuction type ('container'); in reality AWS + Batch will ignore this config. + """ numNodes: int = 1 mainNode: int = 0 - nodeRangeProperties: List[AWSBatchJobDefinitionNodePropertiesNodeRangeProperty] = [] + nodeRangeProperties: List[ + AWSBatchJobDefinitionNodePropertiesNodeRangeProperty + ] = [] class AWSBatchJobDefinitionRetryStrategy(BaseModel): + """An AWS Batch job subconfiguration model for retry specifications.""" attempts: int = 2 evaluateOnExit: List[Dict[str,str]] = [ { @@ -102,15 +115,15 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): "action": "EXIT" } ] - # Example: - # { - # 'onStatusReason': 'string', - # 'onReason': 'string', - # 'onExitCode': 'string', - # 'action': 'RETRY'|'EXIT' - # }, class AWSBatchJobDefinition(BaseModel): + """A utility to validate AWS Batch job descriptions. + + Defaults fall into two categories: + - reasonable default values + - aligning the job description to be a valid 'container' type configuration, + as multinode jobs are not supported yet.""" + jobDefinitionName: str type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version parameters: Dict[str,str] = {} @@ -126,15 +139,15 @@ class AWSBatchJobDefinition(BaseModel): class AWSBatchStepOperator(BaseStepOperator): - """Step operator to run a step on Sagemaker. 
+ """Step operator to run a step on AWS Batch. This class defines code that builds an image with the ZenML entrypoint - to run using Sagemaker's Estimator. + to run using AWS Batch. """ @property def config(self) -> AWSBatchStepOperatorConfig: - """Returns the `SagemakerStepOperatorConfig` config. + """Returns the `AWSBatchStepOperatorConfig` config. Returns: The configuration. @@ -143,7 +156,7 @@ def config(self) -> AWSBatchStepOperatorConfig: @property def settings_class(self) -> Optional[Type["BaseSettings"]]: - """Settings class for the SageMaker step operator. + """Settings class for the AWS Batch step operator. Returns: The settings class. @@ -204,11 +217,117 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: custom_validation_function=_validate_remote_components, ) + @staticmethod + def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: + """Utility to map the {name:value} environment to the + [{"name":name,"value":value},] convention used in the AWS Batch job + definition spec. + + Args: + environment (Dict[str,str]): The step's environment variable + specification + + Returns: + List[Dict[str,str]]: The mapped environment variable specification + """ + + return [ + {"name":k,"value":v} for k,v in environment + ] + + @staticmethod + def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,str]]: + """Utility to map the resource_settings to the resource convention used + in the AWS Batch Job definition spec. + + Args: + resource_settings (ResourceSettings): The step's resource settings. + + Returns: + List[Dict[str,str]]: The mapped resource settings. 
+ """ + mapped_resource_settings = [] + + if resource_settings.empty: + return mapped_resource_settings + else: + + if resource_settings.cpu_count is not None: + mapped_resource_settings.append( + { + "value": resource_settings.cpu_count, + "type": 'VCPU' + } + ) + + if resource_settings.gpu_count is not None: + mapped_resource_settings.append( + { + "value": resource_settings.gpu_count, + "type": 'GPU' + } + ) + + if resource_settings.get_memory() is not None: + mapped_resource_settings.append( + { + "value": resource_settings.get_memory(), + "type": 'MEMORY' + } + ) + + return mapped_resource_settings + + @staticmethod + def generate_unique_batch_job_name(info: "StepRunInfo") -> str: + """Utility to generate a unique AWS Batch job name. + + Args: + info (StepRunInfo): The step run information. + + Returns: + str: A unique name for the step's AWS Batch job definition + """ + + # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. + step_name = Client().get_run_step(info.step_run_id).name + job_name = f"{info.pipeline.name}-{step_name}"[:55] + suffix = random_str(4) + return f"{job_name}-{suffix}" def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: """Utility to map zenml internal configurations to a valid AWS Batch job definition.""" - pass + + image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) + + resource_settings = info.config.resource_settings + step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + + job_name = self.generate_unique_batch_job_name(info) + + return AWSBatchJobDefinition( + jobDefinitionName=job_name, + containerProperties=AWSBatchJobDefinitionContainerProperties( + executionRoleArn=self.config.execution_role, + jobRoleArn=self.config.job_role, + image=image_name, + command=entrypoint_command, + environment=self.map_environment(environment), + instanceType=step_settings.instance_type, + 
resourceRequirements=self.map_resource_settings(resource_settings), + ), + timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + # parameters: Dict[str,str] = {} + # schedulingPriority: int = 0 # ignored in FIFO queues + # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( + # numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + # retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() + # propagateTags: bool = False + # tags: Dict[str,str] = {} + # platformCapabilities: Literal['EC2','FARGATE'] = "EC2" + ) def get_docker_builds( @@ -264,33 +383,20 @@ def launch( self.name, ) - image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY) - - settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + job_definition = self.generate_job_definition(info, entrypoint_command, environment) batch = boto3.client('batch') - - # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. 
- step_name = Client().get_run_step(info.step_run_id).name - training_job_name = f"{info.pipeline.name}-{step_name}"[:55] - suffix = random_str(4) - unique_training_job_name = f"{training_job_name}-{suffix}" - + response = batch.register_job_definition( - jobDefinitionName=unique_training_job_name, - type='container', - containerProperties={ - 'image': image_name , - 'command': entrypoint_command, - } + **job_definition.model_dump() ) - job_definition = response['jobDefinitionName'] + job_definition_name = response['jobDefinitionName'] response = batch.submit_job( - jobName=unique_training_job_name, + jobName=job_definition.jobDefinitionName, jobQueue=self.config.job_queue_name, - jobDefinition=job_definition, + jobDefinition=job_definition_name, ) job_id = response['jobId'] From 05433317b18437ac52d72bd7f977cd8c2c4f04e1 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 13 Sep 2025 20:48:41 +0100 Subject: [PATCH 11/36] renaming modules and adding unit tests --- pyproject.toml | 2 +- .../integrations/aws/flavors/__init__.py | 2 +- ...r.py => aws_batch_step_operator_flavor.py} | 9 +- .../aws/step_operators/__init__.py | 2 +- ...operator.py => aws_batch_step_operator.py} | 30 +++---- ..._batch_step_operator_entrypoint_config.py} | 19 +--- .../aws/step_operators/__init__.py | 13 +++ .../test_aws_batch_step_operator.py | 86 +++++++++++++++++++ .../test_aws_batch_step_operator_flavor.py | 8 ++ 9 files changed, 130 insertions(+), 41 deletions(-) rename src/zenml/integrations/aws/flavors/{batch_step_operator_flavor.py => aws_batch_step_operator_flavor.py} (96%) rename src/zenml/integrations/aws/step_operators/{batch_step_operator.py => aws_batch_step_operator.py} (94%) rename src/zenml/integrations/aws/step_operators/{batch_step_operator_entrypoint_config.py => aws_batch_step_operator_entrypoint_config.py} (56%) create mode 100644 tests/integration/integrations/aws/step_operators/__init__.py create mode 100644 
tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py create mode 100644 tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py diff --git a/pyproject.toml b/pyproject.toml index 61e3bce7497..11213fe1442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,7 @@ dev = [ "darglint", "pytest-randomly", "pytest-mock", - "pytest-clarity", + #"pytest-clarity", "pytest-instafail", "pytest-rerunfailures", "pytest-split", diff --git a/src/zenml/integrations/aws/flavors/__init__.py b/src/zenml/integrations/aws/flavors/__init__.py index d5c0c35ea64..bc32651cc72 100644 --- a/src/zenml/integrations/aws/flavors/__init__.py +++ b/src/zenml/integrations/aws/flavors/__init__.py @@ -29,7 +29,7 @@ SagemakerStepOperatorConfig, SagemakerStepOperatorFlavor, ) -from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import ( AWSBatchStepOperatorConfig, AWSBatchStepOperatorFlavor ) diff --git a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py similarity index 96% rename from src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py rename to src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index ee60a46f3e6..f814da513a4 100644 --- a/src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, Optional, Type from pydantic import Field @@ -27,7 +27,6 @@ BaseStepOperatorConfig, BaseStepOperatorFlavor, ) -from zenml.utils import deprecation_utils if TYPE_CHECKING: from zenml.integrations.aws.step_operators import AWSBatchStepOperator @@ -36,13 +35,13 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: Optional[str] = Field( - 'optimal', + instance_type: str = Field( + default='optimal', description="The instance type for AWS Batch to use for the step" \ " execution. Example: 'm5.xlarge'", ) environment: Dict[str, str] = Field( - default_factory=dict, + default={}, description="Environment variables to pass to the container during " \ "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index b2d32099141..d04fccea431 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -16,7 +16,7 @@ from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401 SagemakerStepOperator, ) -from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401 +from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, ) __all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py similarity index 94% rename from src/zenml/integrations/aws/step_operators/batch_step_operator.py rename to src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index b8e93745055..4604b2a4064 100644 --- 
a/src/zenml/integrations/aws/step_operators/batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -14,6 +14,7 @@ """Implementation of the Sagemaker Step Operator.""" import time +import math from typing import ( TYPE_CHECKING, Dict, @@ -21,26 +22,20 @@ Optional, Tuple, Type, - Union, Literal, cast, ) from pydantic import BaseModel - import boto3 -from sagemaker.estimator import Estimator -from sagemaker.inputs import TrainingInput -from sagemaker.session import Session from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration from zenml.enums import StackComponentType -from zenml.integrations.aws.flavors.batch_step_operator_flavor import ( +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import ( AWSBatchStepOperatorConfig, AWSBatchStepOperatorSettings, ) -from zenml.integrations.aws.step_operators.batch_step_operator_entrypoint_config import ( - BATCH_STEP_ENV_VAR_SIZE_LIMIT, +from zenml.integrations.aws.step_operators.aws_batch_step_operator_entrypoint_config import ( AWSBatchEntrypointConfiguration, ) from zenml.logger import get_logger @@ -49,7 +44,6 @@ from zenml.step_operators.step_operator_entrypoint_configuration import ( StepOperatorEntrypointConfiguration, ) -from zenml.utils.env_utils import split_environment_variables from zenml.utils.string_utils import random_str if TYPE_CHECKING: @@ -60,7 +54,7 @@ logger = get_logger(__name__) -BATCH_DOCKER_IMAGE_KEY = "batch_step_operator" +BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" class AWSBatchJobDefinitionContainerProperties(BaseModel): @@ -232,11 +226,11 @@ def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: """ return [ - {"name":k,"value":v} for k,v in environment + {"name":k,"value":v} for k,v in environment.items() ] @staticmethod - def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,str]]: + def 
map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[str,str]]: """Utility to map the resource_settings to the resource convention used in the AWS Batch Job definition spec. @@ -253,9 +247,15 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, else: if resource_settings.cpu_count is not None: + + cpu_count_int = math.ceil(resource_settings.cpu_count) + + if cpu_count_int != resource_settings.cpu_count: + logger.info(f"AWS Batch only accepts int type cpu resource requirements. Converted {resource_settings.cpu_count} to {cpu_count_int}") + mapped_resource_settings.append( { - "value": resource_settings.cpu_count, + "value": str(cpu_count_int), "type": 'VCPU' } ) @@ -263,7 +263,7 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, if resource_settings.gpu_count is not None: mapped_resource_settings.append( { - "value": resource_settings.gpu_count, + "value": str(resource_settings.gpu_count), "type": 'GPU' } ) @@ -271,7 +271,7 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str, if resource_settings.get_memory() is not None: mapped_resource_settings.append( { - "value": resource_settings.get_memory(), + "value": str(int(resource_settings.get_memory(unit="MiB"))), "type": 'MEMORY' } ) diff --git a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py similarity index 56% rename from src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py rename to src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py index 9cc99c6c4c8..cd57398a38f 100644 --- a/src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py @@ -16,24 +16,7 @@ from 
zenml.step_operators.step_operator_entrypoint_configuration import ( StepOperatorEntrypointConfiguration, ) -from zenml.utils.env_utils import reconstruct_environment_variables - -BATCH_STEP_ENV_VAR_SIZE_LIMIT = 512 class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration): - """Entrypoint configuration for ZenML Batch step operator. - - The only purpose of this entrypoint configuration is to reconstruct the - environment variables that exceed the maximum length of 512 characters - allowed for Batch steps from their individual components. - """ - - def run(self) -> None: - """Runs the step.""" - # Reconstruct the environment variables that exceed the maximum length - # of 512 characters from their individual chunks - reconstruct_environment_variables() - - # Run the step - super().run() + """Entrypoint configuration for ZenML Batch step operator.""" \ No newline at end of file diff --git a/tests/integration/integrations/aws/step_operators/__init__.py b/tests/integration/integrations/aws/step_operators/__init__.py new file mode 100644 index 00000000000..cd90a82cfc2 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py new file mode 100644 index 00000000000..6ab344da4f1 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -0,0 +1,86 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + + +import pytest + +from zenml.config.resource_settings import ResourceSettings +from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator + +def test_aws_batch_step_operator_map_environment(): + test_environment = {'key_1':'value_1','key_2':'value_2'} + expected = [ + { + "name": "key_1", + "value": "value_1" + }, + { + "name": "key_2", + "value": "value_2" + } + ] + + assert AWSBatchStepOperator.map_environment(test_environment) == expected + +@pytest.mark.parametrize( + "test_resource_settings,expected", + [ + ( + ResourceSettings(), + [] + ), + ( + ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"), + [ + { + "value":"1", + "type":"VCPU" + }, + { + "value":"1", + "type":"GPU" + }, + { + "value": "10", + "type": "MEMORY" + } + ] + ), + ( + ResourceSettings(cpu_count=1,gpu_count=1), + [ + { + "value":"1", + "type":"VCPU" + }, + { + "value":"1", + "type":"GPU" + }, + ] + ), + ( + ResourceSettings(memory="1GiB"), + [ + { + "value": "1024", + "type": "MEMORY" + } + ] + ), + + ] 
+) +def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected): + assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py new file mode 100644 index 00000000000..2a160dd8127 --- /dev/null +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py @@ -0,0 +1,8 @@ +from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings + +def test_aws_batch_step_operator_settings(): + AWSBatchStepOperatorSettings( + instance_type="g4dn.xlarge", + environment={"key_1":"value_1","key_2":"value_2"}, + timeout_seconds=60 + ) \ No newline at end of file From c9b58292dec864a0c1ed25bef98091e8af9b01c7 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 00:18:38 +0100 Subject: [PATCH 12/36] added support for multinode aws batch job type --- .../flavors/aws_batch_step_operator_flavor.py | 16 ++++-- .../step_operators/aws_batch_step_operator.py | 51 ++++++++++++------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index f814da513a4..5eb3520f746 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -13,9 +13,9 @@ # permissions and limitations under the License. 
"""Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union -from pydantic import Field +from pydantic import Field, PositiveInt from zenml.config.base_settings import BaseSettings from zenml.integrations.aws import ( @@ -35,7 +35,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: str = Field( + instance_type: Union[str,List[str]] = Field( default='optimal', description="The instance type for AWS Batch to use for the step" \ " execution. Example: 'm5.xlarge'", @@ -45,11 +45,19 @@ class AWSBatchStepOperatorSettings(BaseSettings): description="Environment variables to pass to the container during " \ "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - timeout_seconds: int = Field( + node_count: PositiveInt = Field( + description="The number of AWS Batch nodes to run the step on. If > 1," \ + "an AWS Batch multinode job will be run, with the network connectivity" \ + "between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \ + "for details." + ) + timeout_seconds: PositiveInt = Field( default=120, description="The number of seconds before AWS Batch times out the job." 
) + + class AWSBatchStepOperatorConfig( BaseStepOperatorConfig, AWSBatchStepOperatorSettings ): diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 4604b2a4064..a90863e7c60 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,7 +25,7 @@ Literal, cast, ) -from pydantic import BaseModel +from pydantic import BaseModel, PositiveInt import boto3 from zenml.client import Client @@ -71,22 +71,14 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel): class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): """An AWS Batch job subconfiguration model for a node in a multinode job specifications. - - Note: We include this class for completeness sake to make it easier to add - multinode support later for now. """ targetNodes: str container: AWSBatchJobDefinitionContainerProperties class AWSBatchJobDefinitionNodeProperties(BaseModel): """An AWS Batch job subconfiguration model for multinode job specifications. - - Note: We include this class for completeness sake to make it easier to add - multinode support later for now, we'll set defaults to intuitively - represent the only supported exeuction type ('container'); in reality AWS - Batch will ignore this config. 
""" - numNodes: int = 1 + numNodes: PositiveInt = 1 mainNode: int = 0 nodeRangeProperties: List[ AWSBatchJobDefinitionNodePropertiesNodeRangeProperty @@ -94,7 +86,7 @@ class AWSBatchJobDefinitionNodeProperties(BaseModel): class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" - attempts: int = 2 + attempts: PositiveInt = 2 evaluateOnExit: List[Dict[str,str]] = [ { "onExitCode": "137", # out-of-memory killed @@ -122,9 +114,8 @@ class AWSBatchJobDefinition(BaseModel): type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version parameters: Dict[str,str] = {} schedulingPriority: int = 0 # ignored in FIFO queues - containerProperties: AWSBatchJobDefinitionContainerProperties - nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( - numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well + containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None + nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() propagateTags: bool = False timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' @@ -305,10 +296,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) job_name = self.generate_unique_batch_job_name(info) - - return AWSBatchJobDefinition( - jobDefinitionName=job_name, - containerProperties=AWSBatchJobDefinitionContainerProperties( + container_properties = AWSBatchJobDefinitionContainerProperties( executionRoleArn=self.config.execution_role, jobRoleArn=self.config.job_role, image=image_name, @@ -317,8 +305,33 @@ def 
generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ instanceType=step_settings.instance_type, resourceRequirements=self.map_resource_settings(resource_settings), ), + + node_count = step_settings.node_count + + if node_count == 1: + kwargs = { + 'type':'container', + 'containerProperties':container_properties + } + else: + kwargs = { + 'type':'multinode', + 'nodeProperties':AWSBatchJobDefinitionNodeProperties( + numNodes=node_count, + nodeRangeProperties=[ + AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( + targetNodes=','.join([str(node_index) for node_index in range(node_count)]), + container=container_properties + ) + ] + ) + } + + + return AWSBatchJobDefinition( + jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, - # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + **kwargs # parameters: Dict[str,str] = {} # schedulingPriority: int = 0 # ignored in FIFO queues # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( From c787379597bc8e00b3abd75f0e44c9eb90907a39 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 00:59:04 +0100 Subject: [PATCH 13/36] added support for multinode aws batch job type --- .../flavors/aws_batch_step_operator_flavor.py | 1 + .../aws/step_operators/__init__.py | 3 +- .../step_operators/aws_batch_step_operator.py | 40 ++++++++++++++++++- .../test_aws_batch_step_operator.py | 20 +++++++++- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 5eb3520f746..80e549d61bd 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -46,6 +46,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): "execution. 
Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) node_count: PositiveInt = Field( + default=1, description="The number of AWS Batch nodes to run the step on. If > 1," \ "an AWS Batch multinode job will be run, with the network connectivity" \ "between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \ diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index d04fccea431..e680dc1d063 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -18,5 +18,6 @@ ) from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, + get_aws_batch_context ) -__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_aws_batch_context"] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index a90863e7c60..a2e901f66c4 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,7 +25,8 @@ Literal, cast, ) -from pydantic import BaseModel, PositiveInt +from pydantic import BaseModel, PositiveInt, Field +from pydantic_settings import BaseSettings as PydanticBaseSettings import boto3 from zenml.client import Client @@ -57,6 +58,43 @@ BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" +class AWSBatchContext(PydanticBaseSettings): + """A utility to access the AWS Batch job context environment variables.""" + main_node_index: int = Field( + description="This variable is set to the index number of the job's " + "main node. 
Your application code can compare the " + "AWS_BATCH_JOB_MAIN_NODE_INDEX to the AWS_BATCH_JOB_NODE_INDEX on " + "an individual node to determine if it's the main node.", + alias="AWS_BATCH_JOB_MAIN_NODE_INDEX" + ) + + main_node_address: str = Field( + description="This variable is only set in multi-node parallel job " + "child nodes. This variable isn't present on the main node. This " + "variable is set to the private IPv4 address of the job's main " + "node. Your child node's application code can use this address to " + "communicate with the main node.", + alias="AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS" + ) + + node_index: int = Field( + description="This variable is set to the node index number of the " + "node. The node index begins at 0, and each node receives a unique" + " index number. For example, a multi-node parallel job with 10 " + "children has index values of 0-9.", + alias="AWS_BATCH_JOB_NODE_INDEX" + ) + + num_nodes: int = Field( + description="This variable is set to the number of nodes that you have" + "requested for your multi-node parallel job.", + alias="AWS_BATCH_JOB_NUM_NODES" + ) + +def get_aws_batch_context() -> AWSBatchContext: + """Utility to retrieve the AWS Batch runtime context.""" + return AWSBatchContext() + class AWSBatchJobDefinitionContainerProperties(BaseModel): """An AWS Batch job subconfiguration model for a container specification.""" image: str diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index 6ab344da4f1..b032818cfbf 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -16,9 +16,25 @@ import pytest from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator +from 
zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_aws_batch_context + +def test_aws_batch_context(monkeypatch): + """Tests the AWSBatchContext class.""" + + monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_INDEX',0) + monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS','test-address') + monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1) + monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2) + + test_aws_batch_context = get_aws_batch_context() + assert test_aws_batch_context.main_node_index == 0 + assert test_aws_batch_context.main_node_address == 'test-address' + assert test_aws_batch_context.node_index == 1 + assert test_aws_batch_context.num_nodes == 2 def test_aws_batch_step_operator_map_environment(): + """Tests the AWSBatchStepOperator's map_environment class method.""" + test_environment = {'key_1':'value_1','key_2':'value_2'} expected = [ { @@ -83,4 +99,6 @@ def test_aws_batch_step_operator_map_environment(): ] ) def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected): + """Tests the AWSBatchStepOperator's map_resource_settings class method.""" + assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected From 5fd07616e1ba58a08d0991974220e17d9c2853a0 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 19:35:38 +0100 Subject: [PATCH 14/36] adding test dependency back in and fixing typo in sagemaker doc string --- pyproject.toml | 2 +- .../aws/step_operators/sagemaker_step_operator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 11213fe1442..61e3bce7497 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,7 @@ dev = [ "darglint", "pytest-randomly", "pytest-mock", - #"pytest-clarity", + "pytest-clarity", "pytest-instafail", "pytest-rerunfailures", "pytest-split", diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py 
b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py index 2df9a81d21e..6456d4d9a37 100644 --- a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py @@ -121,12 +121,12 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: if container_registry.config.is_local: return False, ( - "The Batch step operator runs code remotely and " + "The SageMaker step operator runs code remotely and " "needs to push/pull Docker images, but the " f"container registry `{container_registry.name}` of the " "active stack is local. Please ensure that your stack " "contains a remote container registry when using the " - "Batch step operator." + "SageMaker step operator." ) return True, "" From 5466799c1ec7f00c56767c76f6926058ebfb98fe Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 14 Sep 2025 19:39:20 +0100 Subject: [PATCH 15/36] renaming the aws batch runtime context retrieval utility --- src/zenml/integrations/aws/step_operators/__init__.py | 4 ++-- .../aws/step_operators/aws_batch_step_operator.py | 2 +- .../aws/step_operators/test_aws_batch_step_operator.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index e680dc1d063..ecc24b98f0e 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -18,6 +18,6 @@ ) from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, - get_aws_batch_context + get_context ) -__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_aws_batch_context"] +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_context"] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py 
b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index a2e901f66c4..69ef0b49efe 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -91,7 +91,7 @@ class AWSBatchContext(PydanticBaseSettings): alias="AWS_BATCH_JOB_NUM_NODES" ) -def get_aws_batch_context() -> AWSBatchContext: +def get_context() -> AWSBatchContext: """Utility to retrieve the AWS Batch runtime context.""" return AWSBatchContext() diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index b032818cfbf..3aaa7bf9a43 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -16,7 +16,7 @@ import pytest from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_aws_batch_context +from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_context def test_aws_batch_context(monkeypatch): """Tests the AWSBatchContext class.""" @@ -26,7 +26,7 @@ def test_aws_batch_context(monkeypatch): monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1) monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2) - test_aws_batch_context = get_aws_batch_context() + test_aws_batch_context = get_context() assert test_aws_batch_context.main_node_index == 0 assert test_aws_batch_context.main_node_address == 'test-address' assert test_aws_batch_context.node_index == 1 From 17de12b534ee797f882430bc6ec8f691b92ff408 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 22:20:21 +0100 Subject: [PATCH 16/36] bounding aws integration dependency boto3 < 2 --- src/zenml/integrations/aws/__init__.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index 231d81a6bb0..477c7c56510 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -43,7 +43,7 @@ class AWSIntegration(Integration): NAME = AWS REQUIREMENTS = [ "sagemaker>=2.237.3", - "boto3>=1.40.30", + "boto3>=1.40.30,<2.0", "kubernetes", "aws-profile-manager", ] From 1fcefac2b0c71cf8429264b604e2671b65070621 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 22:24:26 +0100 Subject: [PATCH 17/36] using immutable default dict factory instead of mutable empty dict value --- .../integrations/aws/flavors/aws_batch_step_operator_flavor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 80e549d61bd..33dcfe28833 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -41,7 +41,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): " execution. Example: 'm5.xlarge'", ) environment: Dict[str, str] = Field( - default={}, + default_factory=dict, description="Environment variables to pass to the container during " \ "execution. 
Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) From eb6c3205aaf7829882eb003b94d82edb6e7c2d3c Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 22:33:20 +0100 Subject: [PATCH 18/36] removing commented out default args --- .../aws/step_operators/aws_batch_step_operator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 69ef0b49efe..88eb4eef9b0 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -370,14 +370,6 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, **kwargs - # parameters: Dict[str,str] = {} - # schedulingPriority: int = 0 # ignored in FIFO queues - # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties( - # numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well - # retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() - # propagateTags: bool = False - # tags: Dict[str,str] = {} - # platformCapabilities: Literal['EC2','FARGATE'] = "EC2" ) From d1c002b698d945e83d5bb4476907e91c4af29e1e Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 22:35:03 +0100 Subject: [PATCH 19/36] removing incorrect warning stating that step level resources specification is not configured --- .../aws/step_operators/aws_batch_step_operator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py 
index 88eb4eef9b0..c9235dfaced 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -415,16 +415,6 @@ def launch( RuntimeError: If the connector returns an object that is not a `boto3.Session`. """ - if not info.config.resource_settings.empty: - logger.warning( - "Specifying custom step resources is not supported for " - "the AWS Batch step operator. If you want to run this step " - "operator on specific resources, you can do so by configuring " - "a different instance type like this: " - "`zenml step-operator update %s " - "--instance_type=`", - self.name, - ) job_definition = self.generate_job_definition(info, entrypoint_command, environment) From 98e014e38c4748cb9111ccffad415ee1dff63779 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 22:51:51 +0100 Subject: [PATCH 20/36] increased timeout error to 1h and added batch client error handling --- .../flavors/aws_batch_step_operator_flavor.py | 2 +- .../step_operators/aws_batch_step_operator.py | 22 ++++++++++++++----- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 33dcfe28833..3c0e882ccc7 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -53,7 +53,7 @@ class AWSBatchStepOperatorSettings(BaseSettings): "for details." ) timeout_seconds: PositiveInt = Field( - default=120, + default=3600, description="The number of seconds before AWS Batch times out the job." 
) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index c9235dfaced..345d9edcbfc 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -46,6 +46,7 @@ StepOperatorEntrypointConfiguration, ) from zenml.utils.string_utils import random_str +from botocore.exceptions import ClientError if TYPE_CHECKING: from zenml.config.base_settings import BaseSettings @@ -435,9 +436,18 @@ def launch( job_id = response['jobId'] while True: - response = batch.describe_jobs(jobs=[job_id]) - status = response['jobs'][0]['status'] - if status in ['SUCCEEDED', 'FAILED']: - break - time.sleep(10) - logger.info(f'Job completed with status {status}') \ No newline at end of file + try: + response = batch.describe_jobs(jobs=[job_id]) + status = response['jobs'][0]['status'] + + if status == ['SUCCEEDED']: + logger.info(f"Job completed successfully: {job_id}") + break + elif status == ["FAILED"]: + status_reason = response['jobs'][0].get('statusReason', 'Unknown') + raise RuntimeError(f'Job {job_id} failed: {status_reason}') + else: + time.sleep(10) + except ClientError as e: + logger.error(f"Failed to describe job {job_id}: {e}") + raise \ No newline at end of file From 1be59658e95188a4257faad7984effdccbae17a2 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 16 Sep 2025 23:27:03 +0100 Subject: [PATCH 21/36] replicated the sagemaker orchestrator aws authentication and session creation approach. 
extended the flavour config to include the required aws auth related fields --- .../flavors/aws_batch_step_operator_flavor.py | 36 +++++++++--- .../step_operators/aws_batch_step_operator.py | 58 +++++++++++++++++-- 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 3c0e882ccc7..b318b604f27 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union from pydantic import Field, PositiveInt - +from zenml.utils.secret_utils import SecretField from zenml.config.base_settings import BaseSettings from zenml.integrations.aws import ( AWS_RESOURCE_TYPE, @@ -71,18 +71,40 @@ class AWSBatchStepOperatorConfig( """ execution_role: str = Field( - "", - description="The ECS execution role required to execute the AWS Batch" \ - " jobs as an ECS tasks." + description="The IAM role arn of the ECS execution role." ) job_role: str = Field( - "", - description="The ECS job role required by the container runtime inside" \ - "the ECS task." + description="The IAM role arn of the ECS job role." ) job_queue_name: str = Field( description="The AWS Batch job queue to submit AWS Batch jobs to." ) + aws_access_key_id: Optional[str] = SecretField( + default=None, + description="The AWS access key ID to use to authenticate to AWS. " + "If not provided, the value from the default AWS config will be used.", + ) + aws_secret_access_key: Optional[str] = SecretField( + default=None, + description="The AWS secret access key to use to authenticate to AWS. 
" + "If not provided, the value from the default AWS config will be used.", + ) + aws_profile: Optional[str] = Field( + None, + description="The AWS profile to use for authentication if not using " + "service connectors or explicit credentials. If not provided, the " + "default profile will be used.", + ) + aws_auth_role_arn: Optional[str] = Field( + None, + description="The ARN of an intermediate IAM role to assume when " + "authenticating to AWS.", + ) + region: Optional[str] = Field( + None, + description="The AWS region where the processing job will be run. " + "If not provided, the value from the default AWS config will be used.", + ) @property def is_remote(self) -> bool: diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 345d9edcbfc..fb017e2b3b9 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -27,7 +27,7 @@ ) from pydantic import BaseModel, PositiveInt, Field from pydantic_settings import BaseSettings as PydanticBaseSettings -import boto3 +from boto3 import Session from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration @@ -186,6 +186,53 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: The settings class. """ return AWSBatchStepOperatorSettings + + def _get_aws_session(self) -> Session: + """Method to create the AWS Batch session with proper authentication. + + Returns: + The AWS Batch session. + + Raises: + RuntimeError: If the connector returns the wrong type for the + session. 
+ """ + # Get authenticated session + # Option 1: Service connector + boto_session: Session + if connector := self.get_connector(): + boto_session = connector.connect() + if not isinstance(boto_session, Session): + raise RuntimeError( + f"Expected to receive a `boto3.Session` object from the " + f"linked connector, but got type `{type(boto_session)}`." + ) + # Option 2: Explicit configuration + # Args that are not provided will be taken from the default AWS config. + else: + boto_session = Session( + aws_access_key_id=self.config.aws_access_key_id, + aws_secret_access_key=self.config.aws_secret_access_key, + region_name=self.config.region, + profile_name=self.config.aws_profile, + ) + # If a role ARN is provided for authentication, assume the role + if self.config.aws_auth_role_arn: + sts = boto_session.client("sts") + response = sts.assume_role( + RoleArn=self.config.aws_auth_role_arn, + RoleSessionName="zenml-aws-batch-step-operator", + ) + credentials = response["Credentials"] + boto_session = Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + region_name=self.config.region, + ) + return Session( + boto_session=boto_session, + ) @property def entrypoint_config_class( @@ -419,15 +466,16 @@ def launch( job_definition = self.generate_job_definition(info, entrypoint_command, environment) - batch = boto3.client('batch') + boto_session = self._get_aws_session() + batch_client = boto_session.client('batch') - response = batch.register_job_definition( + response = batch_client.register_job_definition( **job_definition.model_dump() ) job_definition_name = response['jobDefinitionName'] - response = batch.submit_job( + response = batch_client.submit_job( jobName=job_definition.jobDefinitionName, jobQueue=self.config.job_queue_name, jobDefinition=job_definition_name, @@ -437,7 +485,7 @@ def launch( while True: try: - response = batch.describe_jobs(jobs=[job_id]) 
+ response = batch_client.describe_jobs(jobs=[job_id]) status = response['jobs'][0]['status'] if status == ['SUCCEEDED']: From 070ef62f36defdc7b1d049fa325df8c404a46703 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 21 Sep 2025 01:57:54 +0100 Subject: [PATCH 22/36] fixes off the back initial functional testing --- src/zenml/integrations/aws/__init__.py | 1 - .../step_operators/aws_batch_step_operator.py | 165 ++++++++++++------ .../test_aws_batch_step_operator_flavor.py | 33 +++- 3 files changed, 145 insertions(+), 54 deletions(-) diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index 477c7c56510..03876585f5e 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -43,7 +43,6 @@ class AWSIntegration(Integration): NAME = AWS REQUIREMENTS = [ "sagemaker>=2.237.3", - "boto3>=1.40.30,<2.0", "kubernetes", "aws-profile-manager", ] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 5874fdf04ee..84b4feaa7a3 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -96,23 +96,26 @@ def get_context() -> AWSBatchContext: """Utility to retrieve the AWS Batch runtime context.""" return AWSBatchContext() -class AWSBatchJobDefinitionContainerProperties(BaseModel): - """An AWS Batch job subconfiguration model for a container specification.""" +class AWSBatchJobDefinitionContainerTypeContainerProperties(BaseModel): + """An AWS Batch job subconfiguration model for a container type job's container specification.""" image: str command: List[str] jobRoleArn: str executionRoleArn: str environment: List[Dict[str,str]] = [] # keys: 'name','value' - instanceType: str resourceRequirements: List[Dict[str,str]] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' secrets: 
List[Dict[str,str]] = [] # keys: 'name','value' +class AWSBatchJobDefinitionMultinodeTypeContainerProperties(AWSBatchJobDefinitionContainerTypeContainerProperties): + """An AWS Batch job subconfiguration model for a multinode type job's container specification.""" + instanceType: Optional[str] = "" + class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): """An AWS Batch job subconfiguration model for a node in a multinode job specifications. """ targetNodes: str - container: AWSBatchJobDefinitionContainerProperties + container: AWSBatchJobDefinitionMultinodeTypeContainerProperties class AWSBatchJobDefinitionNodeProperties(BaseModel): """An AWS Batch job subconfiguration model for multinode job specifications. @@ -127,21 +130,20 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" attempts: PositiveInt = 2 evaluateOnExit: List[Dict[str,str]] = [ - { - "onExitCode": "137", # out-of-memory killed - "action": "RETRY" - }, - { - "onReason": "*Host EC2*", - "action": "RETRY" - }, - { - "onExitCode": "*", # match everything else - "action": "EXIT" - } + # { + # "onExitCode": "137", # out-of-memory killed + # "action": "RETRY" + # }, + # { + # "onReason": "*Host EC2 terminated", + # "action": "RETRY" + # }, + # { + # "action": "EXIT" + # } ] -class AWSBatchJobDefinition(BaseModel): +class AWSBatchJobBaseDefinition(BaseModel): """A utility to validate AWS Batch job descriptions. 
Defaults fall into two categories: @@ -150,17 +152,25 @@ class AWSBatchJobDefinition(BaseModel): as multinode jobs are not supported yet.""" jobDefinitionName: str - type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + type: Literal['container','multinode'] parameters: Dict[str,str] = {} - schedulingPriority: int = 0 # ignored in FIFO queues - containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None - nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None + # schedulingPriority: int = 0 # ignored in FIFO queues retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() propagateTags: bool = False timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' tags: Dict[str,str] = {} - platformCapabilities: Literal['EC2','FARGATE'] = "EC2" #-- hardcode this to EC2, so we can use container and multinode interchangeably without worrying too much + platformCapabilities: List[Literal["EC2","FARGATE"]] = ["EC2"] #-- hardcode this to EC2, so we can use container and multinode interchangeably without worrying too much +class AWSBatchJobContainerTypeDefinition(AWSBatchJobBaseDefinition): + """A utility to validate AWS Batch job descriptions of type contaienr.""" + + type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version + containerProperties: Optional[AWSBatchJobDefinitionContainerTypeContainerProperties] = None + +class AWSBatchJobMultinodeTypeDefinition(AWSBatchJobBaseDefinition): + """A utility to validate AWS Batch job descriptions of type multinode.""" + type: Literal['container','multinode'] = 'multinode' # we dont support multinode type in this version + nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None class AWSBatchStepOperator(BaseStepOperator): """Step operator to run a step on AWS Batch. 
@@ -230,9 +240,7 @@ def _get_aws_session(self) -> Session: aws_session_token=credentials["SessionToken"], region_name=self.config.region, ) - return Session( - boto_session=boto_session, - ) + return boto_session @property def entrypoint_config_class( @@ -320,7 +328,19 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[st mapped_resource_settings = [] if resource_settings.empty: - return mapped_resource_settings + # aws batch job description requires a value for vcpu + mapped_resource_settings.extend( + [ + { + "value": "1", + "type": 'VCPU' + }, + { + "value": "1024", + "type": 'MEMORY' + } + ] + ) else: if resource_settings.cpu_count is not None: @@ -372,7 +392,7 @@ def generate_unique_batch_job_name(info: "StepRunInfo") -> str: suffix = random_str(4) return f"{job_name}-{suffix}" - def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: + def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobContainerTypeDefinition | AWSBatchJobMultinodeTypeDefinition: """Utility to map zenml internal configurations to a valid AWS Batch job definition.""" @@ -382,43 +402,84 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) job_name = self.generate_unique_batch_job_name(info) - container_properties = AWSBatchJobDefinitionContainerProperties( - executionRoleArn=self.config.execution_role, - jobRoleArn=self.config.job_role, - image=image_name, - command=entrypoint_command, - environment=self.map_environment(environment), - instanceType=step_settings.instance_type, - resourceRequirements=self.map_resource_settings(resource_settings), - ), + # container_properties = AWSBatchJobDefinitionContainerProperties( + # executionRoleArn=self.config.execution_role, + # jobRoleArn=self.config.job_role, 
+ # image=image_name, + # command=entrypoint_command, + # environment=self.map_environment(environment), + # resourceRequirements=self.map_resource_settings(resource_settings), + # ).model_dump(exclude='instanceType') node_count = step_settings.node_count if node_count == 1: - kwargs = { - 'type':'container', - 'containerProperties':container_properties - } + return AWSBatchJobContainerTypeDefinition( + jobDefinitionName=job_name, + timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + type="container", + containerProperties=AWSBatchJobDefinitionContainerTypeContainerProperties( + executionRoleArn=self.config.execution_role, + jobRoleArn=self.config.job_role, + image=image_name, + command=entrypoint_command, + environment=self.map_environment(environment), + resourceRequirements=self.map_resource_settings(resource_settings), + ) + ) + # kwargs = { + # 'type':'container', + # 'containerProperties':container_properties + # } + # return AWSBatchJobDefinition( + # jobDefinitionName=job_name, + # timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + # **kwargs + # ).model_dump(exclude='nodeProperties') else: - kwargs = { - 'type':'multinode', - 'nodeProperties':AWSBatchJobDefinitionNodeProperties( + + # kwargs = { + # 'type':'multinode', + # 'nodeProperties':AWSBatchJobDefinitionNodeProperties( + # numNodes=node_count, + # nodeRangeProperties=[ + # AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( + # targetNodes=','.join([str(node_index) for node_index in range(node_count)]), + # container=AWSBatchJobDefinitionContainerProperties( + # executionRoleArn=self.config.execution_role, + # jobRoleArn=self.config.job_role, + # image=image_name, + # command=entrypoint_command, + # environment=self.map_environment(environment), + # instanceType=step_settings.instance_type, + # resourceRequirements=self.map_resource_settings(resource_settings), + # ) + # ) + # ] + # ) + # } + return AWSBatchJobMultinodeTypeDefinition( + jobDefinitionName=job_name, 
+ timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + type="multinode", + nodeProperties=AWSBatchJobDefinitionNodeProperties( numNodes=node_count, nodeRangeProperties=[ AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( targetNodes=','.join([str(node_index) for node_index in range(node_count)]), - container=container_properties + container=AWSBatchJobDefinitionMultinodeTypeContainerProperties( + executionRoleArn=self.config.execution_role, + jobRoleArn=self.config.job_role, + image=image_name, + command=entrypoint_command, + environment=self.map_environment(environment), + instanceType=step_settings.instance_type, + resourceRequirements=self.map_resource_settings(resource_settings), + ) ) ] ) - } - - - return AWSBatchJobDefinition( - jobDefinitionName=job_name, - timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, - **kwargs - ) + ) def get_docker_builds( diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py index 2a160dd8127..b658c47b5ae 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py @@ -1,8 +1,39 @@ from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings +from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchJobDefinition +from zenml import pipeline, step +from zenml.config import DockerSettings +from zenml.client import Client + +# Get the ZenML client +client = Client() def test_aws_batch_step_operator_settings(): AWSBatchStepOperatorSettings( instance_type="g4dn.xlarge", environment={"key_1":"value_1","key_2":"value_2"}, timeout_seconds=60 - ) \ No newline at end of file + ) + +def test_aws_batch_step_job_definition(): + AWSBatchJobDefinition( + 
jobDefinitionName="test-job-name", + ) + +# def test_aws_batch_step_operator_step(): + +# @step( +# step_operator="aws_batch", +# settings={ +# 'docker':DockerSettings(parent_image='test-image'), +# 'step_operator':AWSBatchStepOperatorSettings(instance_type='test-instance') +# } +# ) +# def test_step(name: str) -> str: + +# return f'Hello {name}! I am running on AWS Batch!' + +# @pipeline +# def test_pipeline(name: str): +# test_step(name) + +# test_pipeline("Sebastian",stack=Client().get_stack("aws_batch_stack")) \ No newline at end of file From a398139296f26c48a8752657179262ed8875db0d Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 21 Sep 2025 22:11:17 +0100 Subject: [PATCH 23/36] more changes after successfully e2e testing single node (i.e. aws batch container type) step after hacking + hardcoding some remote rds connection credentials into custom parent image. also needed parent image bc the new flavor module is obvs not installed in official distribution which is in the default docker image used by zenml --- .../step_operators/aws_batch_step_operator.py | 72 ++++--------------- 1 file changed, 15 insertions(+), 57 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 84b4feaa7a3..1f71de169eb 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -130,26 +130,19 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" attempts: PositiveInt = 2 evaluateOnExit: List[Dict[str,str]] = [ - # { - # "onExitCode": "137", # out-of-memory killed - # "action": "RETRY" - # }, - # { - # "onReason": "*Host EC2 terminated", - # "action": "RETRY" - # }, - # { - # "action": "EXIT" - # } + { + "onExitCode": "137", # out-of-memory killed + "action": "RETRY" + }, + { + 
"onReason": "*Host EC2 terminated", # host EC2 rugpulled->try again + "action": "RETRY" + } ] class AWSBatchJobBaseDefinition(BaseModel): - """A utility to validate AWS Batch job descriptions. - - Defaults fall into two categories: - - reasonable default values - - aligning the job description to be a valid 'container' type configuration, - as multinode jobs are not supported yet.""" + """A utility to validate AWS Batch job descriptions. Base class + for container and multinode job definition types.""" jobDefinitionName: str type: Literal['container','multinode'] @@ -402,14 +395,6 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) job_name = self.generate_unique_batch_job_name(info) - # container_properties = AWSBatchJobDefinitionContainerProperties( - # executionRoleArn=self.config.execution_role, - # jobRoleArn=self.config.job_role, - # image=image_name, - # command=entrypoint_command, - # environment=self.map_environment(environment), - # resourceRequirements=self.map_resource_settings(resource_settings), - # ).model_dump(exclude='instanceType') node_count = step_settings.node_count @@ -427,37 +412,9 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ resourceRequirements=self.map_resource_settings(resource_settings), ) ) - # kwargs = { - # 'type':'container', - # 'containerProperties':container_properties - # } - # return AWSBatchJobDefinition( - # jobDefinitionName=job_name, - # timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, - # **kwargs - # ).model_dump(exclude='nodeProperties') - else: - # kwargs = { - # 'type':'multinode', - # 'nodeProperties':AWSBatchJobDefinitionNodeProperties( - # numNodes=node_count, - # nodeRangeProperties=[ - # AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( - # targetNodes=','.join([str(node_index) for node_index in range(node_count)]), - # 
container=AWSBatchJobDefinitionContainerProperties( - # executionRoleArn=self.config.execution_role, - # jobRoleArn=self.config.job_role, - # image=image_name, - # command=entrypoint_command, - # environment=self.map_environment(environment), - # instanceType=step_settings.instance_type, - # resourceRequirements=self.map_resource_settings(resource_settings), - # ) - # ) - # ] - # ) - # } + else: + return AWSBatchJobMultinodeTypeDefinition( jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, @@ -549,13 +506,14 @@ def launch( response = batch_client.describe_jobs(jobs=[job_id]) status = response['jobs'][0]['status'] - if status == ['SUCCEEDED']: + if status == 'SUCCEEDED': logger.info(f"Job completed successfully: {job_id}") break - elif status == ["FAILED"]: + elif status == "FAILED": status_reason = response['jobs'][0].get('statusReason', 'Unknown') raise RuntimeError(f'Job {job_id} failed: {status_reason}') else: + logger.info(f"Job {job_id} neither failed nor succeeded: {status}. 
Waiting another 10 seconds.") time.sleep(10) except ClientError as e: logger.error(f"Failed to describe job {job_id}: {e}") From 1a602eb02618974f11a4c1f56e092a94c6b98ad1 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 21 Sep 2025 22:30:34 +0100 Subject: [PATCH 24/36] fixed step environment settings bug --- .../aws/step_operators/aws_batch_step_operator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 1f71de169eb..91c40debb15 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -135,7 +135,7 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): "action": "RETRY" }, { - "onReason": "*Host EC2 terminated", # host EC2 rugpulled->try again + "onReason": "Host EC2 terminated", # host EC2 rugpulled->try again "action": "RETRY" } ] @@ -394,6 +394,9 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ resource_settings = info.config.resource_settings step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + if step_settings.environment: + environment.update(step_settings.environment) + job_name = self.generate_unique_batch_job_name(info) node_count = step_settings.node_count From 69e60d1c44dfdd701a5f4c9af456d928b70ffbcb Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 21 Sep 2025 23:14:30 +0100 Subject: [PATCH 25/36] fixed the multinode targetnode syntax --- .../aws/step_operators/aws_batch_step_operator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 91c40debb15..c3067f77d32 100644 --- 
a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -417,7 +417,6 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ ) else: - return AWSBatchJobMultinodeTypeDefinition( jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, @@ -426,7 +425,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ numNodes=node_count, nodeRangeProperties=[ AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( - targetNodes=','.join([str(node_index) for node_index in range(node_count)]), + targetNodes=f"0:{node_count-1}", container=AWSBatchJobDefinitionMultinodeTypeContainerProperties( executionRoleArn=self.config.execution_role, jobRoleArn=self.config.job_role, @@ -436,7 +435,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ instanceType=step_settings.instance_type, resourceRequirements=self.map_resource_settings(resource_settings), ) - ) + ), ] ) ) @@ -487,6 +486,8 @@ def launch( job_definition = self.generate_job_definition(info, entrypoint_command, environment) + logger.info(f"Job definition: {job_definition}") + boto_session = self._get_aws_session() batch_client = boto_session.client('batch') From 0d53bcea23685dbf47a46d49d5dcad4d2a4b2f78 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Mon, 22 Sep 2025 01:13:16 +0100 Subject: [PATCH 26/36] fixed type hints for instance type --- .../integrations/aws/flavors/aws_batch_step_operator_flavor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index b318b604f27..2e35ef2de00 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -35,7 +35,7 
@@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: Union[str,List[str]] = Field( + instance_type: Union[str] = Field( default='optimal', description="The instance type for AWS Batch to use for the step" \ " execution. Example: 'm5.xlarge'", From 739fdaa5fdcc5b766bb914f54387f08957b16d08 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Tue, 23 Sep 2025 01:35:30 +0100 Subject: [PATCH 27/36] stripping out multinode support as its not really needed given batch ec2 instances sizes, and instead support both fargate and ec2 (for gpu) backends on ecs --- .../flavors/aws_batch_step_operator_flavor.py | 35 ++- .../aws/step_operators/__init__.py | 3 +- .../step_operators/aws_batch_step_operator.py | 246 +++++++----------- 3 files changed, 112 insertions(+), 172 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 2e35ef2de00..03714552ca2 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union -from pydantic import Field, PositiveInt +from pydantic import Field, PositiveInt, field_validator from zenml.utils.secret_utils import SecretField from zenml.config.base_settings import BaseSettings from zenml.integrations.aws import ( @@ -35,28 +35,35 @@ class AWSBatchStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: Union[str] = Field( - default='optimal', - description="The instance type for AWS Batch to use for the step" \ - " execution. Example: 'm5.xlarge'", - ) environment: Dict[str, str] = Field( default_factory=dict, description="Environment variables to pass to the container during " \ "execution. 
Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", ) - node_count: PositiveInt = Field( - default=1, - description="The number of AWS Batch nodes to run the step on. If > 1," \ - "an AWS Batch multinode job will be run, with the network connectivity" \ - "between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \ - "for details." + job_queue_name: str = Field( + default="", + description="The AWS Batch job queue to submit the step AWS Batch job" + " to. If not provided, falls back to the default job queue name " + "specified at stack registration time." + ) + platform_capability: str = Field( + default="FARGATE", + description="The AWS Batch platform capability for the step AWS Batch " + "job to be orchestrated with. Defaults to 'FARGATE'." ) timeout_seconds: PositiveInt = Field( default=3600, description="The number of seconds before AWS Batch times out the job." ) + @field_validator("platform_capability") + def validate_platform_capability(cls, value): + if value not in ["FARGATE","EC2"]: + raise ValueError(f"Invalid platform capability {value}. Must be " + "either 'FARGATE' or 'EC2'") + + return value + class AWSBatchStepOperatorConfig( @@ -76,8 +83,8 @@ class AWSBatchStepOperatorConfig( job_role: str = Field( description="The IAM role arn of the ECS job role." ) - job_queue_name: str = Field( - description="The AWS Batch job queue to submit AWS Batch jobs to." + default_job_queue_name: str = Field( + description="The default AWS Batch job queue to submit AWS Batch jobs to." 
) aws_access_key_id: Optional[str] = SecretField( default=None, diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index ecc24b98f0e..d04fccea431 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -18,6 +18,5 @@ ) from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401 AWSBatchStepOperator, - get_context ) -__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_context"] +__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"] diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index c3067f77d32..9dc046efb59 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,7 +25,7 @@ Literal, cast, ) -from pydantic import BaseModel, PositiveInt, Field +from pydantic import BaseModel, PositiveInt, Field, field_validator from pydantic_settings import BaseSettings as PydanticBaseSettings from boto3 import Session @@ -59,73 +59,20 @@ BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" -class AWSBatchContext(PydanticBaseSettings): - """A utility to access the AWS Batch job context environment variables.""" - main_node_index: int = Field( - description="This variable is set to the index number of the job's " - "main node. Your application code can compare the " - "AWS_BATCH_JOB_MAIN_NODE_INDEX to the AWS_BATCH_JOB_NODE_INDEX on " - "an individual node to determine if it's the main node.", - alias="AWS_BATCH_JOB_MAIN_NODE_INDEX" - ) - - main_node_address: str = Field( - description="This variable is only set in multi-node parallel job " - "child nodes. This variable isn't present on the main node. 
This " - "variable is set to the private IPv4 address of the job's main " - "node. Your child node's application code can use this address to " - "communicate with the main node.", - alias="AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS" - ) - - node_index: int = Field( - description="This variable is set to the node index number of the " - "node. The node index begins at 0, and each node receives a unique" - " index number. For example, a multi-node parallel job with 10 " - "children has index values of 0-9.", - alias="AWS_BATCH_JOB_NODE_INDEX" - ) - - num_nodes: int = Field( - description="This variable is set to the number of nodes that you have" - "requested for your multi-node parallel job.", - alias="AWS_BATCH_JOB_NUM_NODES" - ) - -def get_context() -> AWSBatchContext: - """Utility to retrieve the AWS Batch runtime context.""" - return AWSBatchContext() - -class AWSBatchJobDefinitionContainerTypeContainerProperties(BaseModel): +class ResourceRequirement(BaseModel): + type: Literal["MEMORY","VCPU","GPU"] + value: str + +class AWSBatchJobDefinitionContainerProperties(BaseModel): """An AWS Batch job subconfiguration model for a container type job's container specification.""" image: str command: List[str] jobRoleArn: str executionRoleArn: str environment: List[Dict[str,str]] = [] # keys: 'name','value' - resourceRequirements: List[Dict[str,str]] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' + resourceRequirements: List[ResourceRequirement] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' secrets: List[Dict[str,str]] = [] # keys: 'name','value' -class AWSBatchJobDefinitionMultinodeTypeContainerProperties(AWSBatchJobDefinitionContainerTypeContainerProperties): - """An AWS Batch job subconfiguration model for a multinode type job's container specification.""" - instanceType: Optional[str] = "" - -class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel): - """An AWS Batch job subconfiguration model for a node in a 
multinode job - specifications. - """ - targetNodes: str - container: AWSBatchJobDefinitionMultinodeTypeContainerProperties - -class AWSBatchJobDefinitionNodeProperties(BaseModel): - """An AWS Batch job subconfiguration model for multinode job specifications. - """ - numNodes: PositiveInt = 1 - mainNode: int = 0 - nodeRangeProperties: List[ - AWSBatchJobDefinitionNodePropertiesNodeRangeProperty - ] = [] - class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" attempts: PositiveInt = 2 @@ -140,30 +87,41 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel): } ] -class AWSBatchJobBaseDefinition(BaseModel): +class AWSBatchJobDefinition(BaseModel): """A utility to validate AWS Batch job descriptions. Base class for container and multinode job definition types.""" jobDefinitionName: str - type: Literal['container','multinode'] + type: str = 'container' + containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None parameters: Dict[str,str] = {} # schedulingPriority: int = 0 # ignored in FIFO queues retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() propagateTags: bool = False - timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds' + timeout: Dict[str,int] = {'attemptDurationSeconds':3600} # key 'attemptDurationSeconds' tags: Dict[str,str] = {} - platformCapabilities: List[Literal["EC2","FARGATE"]] = ["EC2"] #-- hardcode this to EC2, so we can use container and multinode interchangeably without worrying too much + platformCapabilities: List[Literal["EC2","FARGATE"]] -class AWSBatchJobContainerTypeDefinition(AWSBatchJobBaseDefinition): - """A utility to validate AWS Batch job descriptions of type contaienr.""" - - type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version - containerProperties: Optional[AWSBatchJobDefinitionContainerTypeContainerProperties] = None -class 
AWSBatchJobMultinodeTypeDefinition(AWSBatchJobBaseDefinition): - """A utility to validate AWS Batch job descriptions of type multinode.""" - type: Literal['container','multinode'] = 'multinode' # we dont support multinode type in this version - nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None +class AWSBatchJobDefinitionOnEC2(AWSBatchJobDefinition): + platformCapabilities: str = "EC2" + logConfiguration: dict[Literal["logDriver"],Literal["awsfirelens", "awslogs", "fluentd", "gelf", "json-file", "journald", "logentries", "syslog", "splunk"]] = {"logDriver":"awslogs"} + +class AWSBatchJobDefinitionOnFargate(AWSBatchJobDefinition): + platformCapabilities: str = "FARGATE" + logConfiguration: dict[Literal["logDriver"],Literal["awslogs","splunk"]] = {"logDriver":"awslogs"} + networkConfiguration: dict[Literal["assignPublicIp"],Literal["ENABLED","DISABLED"]] = {"assignPublicIp":"DISABLED"} + + @field_validator("containerProperties") + def check_resource_requirements(cls,value: AWSBatchJobDefinitionContainerProperties): + for resource_requirement in value.resourceRequirements: + if resource_requirement.type == "GPU": + raise ValueError( + f"Invalid fargate resource requirement: GPU.Use EC2 " + "platform capability if you need custom devices." + ) + + return value class AWSBatchStepOperator(BaseStepOperator): """Step operator to run a step on AWS Batch. @@ -308,7 +266,7 @@ def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: ] @staticmethod - def map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[str,str]]: + def map_resource_settings(resource_settings: "ResourceSettings") -> List["ResourceRequirement"]: """Utility to map the resource_settings to the resource convention used in the AWS Batch Job definition spec. 
@@ -320,51 +278,43 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[st """ mapped_resource_settings = [] - if resource_settings.empty: - # aws batch job description requires a value for vcpu - mapped_resource_settings.extend( - [ - { - "value": "1", - "type": 'VCPU' - }, - { - "value": "1024", - "type": 'MEMORY' - } - ] - ) - else: - - if resource_settings.cpu_count is not None: - - cpu_count_int = math.ceil(resource_settings.cpu_count) + # handle cpu requirements + if resource_settings.cpu_count is not None: - if cpu_count_int != resource_settings.cpu_count: - logger.info(f"AWS Batch only accepts int type cpu resource requirements. Converted {resource_settings.cpu_count} to {cpu_count_int}") + cpu_count_int = math.ceil(resource_settings.cpu_count) - mapped_resource_settings.append( - { - "value": str(cpu_count_int), - "type": 'VCPU' - } + if cpu_count_int != resource_settings.cpu_count: + logger.info( + "AWS Batch only accepts int type cpu resource " + f"requirements. Converted {resource_settings.cpu_count}" + f" to {cpu_count_int}." 
) + cpu_requirement = ResourceRequirement(value=str(cpu_count_int),type="VCPU") + else: + cpu_requirement = ResourceRequirement(value="1",type="VCPU") - if resource_settings.gpu_count is not None: - mapped_resource_settings.append( - { - "value": str(resource_settings.gpu_count), - "type": 'GPU' - } - ) + mapped_resource_settings.append(cpu_requirement) - if resource_settings.get_memory() is not None: - mapped_resource_settings.append( - { - "value": str(int(resource_settings.get_memory(unit="MiB"))), - "type": 'MEMORY' - } + # handle memory requirements + memory = resource_settings.get_memory(unit="MiB") + if memory: + + memory_requirement = ResourceRequirement( + value=str(int(memory)), + type="MEMORY" + ) + else: + memory_requirement = ResourceRequirement(value="1024",type="MEMORY") + mapped_resource_settings.append(memory_requirement) + + # handle gpu requirements + if resource_settings.gpu_count is not None: + mapped_resource_settings.append( + ResourceRequirement( + value=str(resource_settings.gpu_count), + type="GPU" ) + ) return mapped_resource_settings @@ -385,7 +335,7 @@ def generate_unique_batch_job_name(info: "StepRunInfo") -> str: suffix = random_str(4) return f"{job_name}-{suffix}" - def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobContainerTypeDefinition | AWSBatchJobMultinodeTypeDefinition: + def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: """Utility to map zenml internal configurations to a valid AWS Batch job definition.""" @@ -399,46 +349,24 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ job_name = self.generate_unique_batch_job_name(info) - node_count = step_settings.node_count - - if node_count == 1: - return AWSBatchJobContainerTypeDefinition( - jobDefinitionName=job_name, - timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, 
- type="container", - containerProperties=AWSBatchJobDefinitionContainerTypeContainerProperties( - executionRoleArn=self.config.execution_role, - jobRoleArn=self.config.job_role, - image=image_name, - command=entrypoint_command, - environment=self.map_environment(environment), - resourceRequirements=self.map_resource_settings(resource_settings), - ) - ) - + if step_settings.platform_capability == "EC2": + AWSBatchJobDefinitionClass = AWSBatchJobDefinitionOnEC2 else: - return AWSBatchJobMultinodeTypeDefinition( - jobDefinitionName=job_name, - timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, - type="multinode", - nodeProperties=AWSBatchJobDefinitionNodeProperties( - numNodes=node_count, - nodeRangeProperties=[ - AWSBatchJobDefinitionNodePropertiesNodeRangeProperty( - targetNodes=f"0:{node_count-1}", - container=AWSBatchJobDefinitionMultinodeTypeContainerProperties( - executionRoleArn=self.config.execution_role, - jobRoleArn=self.config.job_role, - image=image_name, - command=entrypoint_command, - environment=self.map_environment(environment), - instanceType=step_settings.instance_type, - resourceRequirements=self.map_resource_settings(resource_settings), - ) - ), - ] - ) + AWSBatchJobDefinitionClass = AWSBatchJobDefinitionOnFargate + + return AWSBatchJobDefinitionClass( + jobDefinitionName=job_name, + timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, + type="container", + containerProperties=AWSBatchJobDefinitionContainerProperties( + executionRoleArn=self.config.execution_role, + jobRoleArn=self.config.job_role, + image=image_name, + command=entrypoint_command, + environment=self.map_environment(environment), + resourceRequirements=self.map_resource_settings(resource_settings), ) + ) def get_docker_builds( @@ -497,9 +425,11 @@ def launch( job_definition_name = response['jobDefinitionName'] + step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info)) + response = batch_client.submit_job( 
jobName=job_definition.jobDefinitionName, - jobQueue=self.config.job_queue_name, + jobQueue=step_settings.job_queue_name if step_settings.job_queue_name else self.config.default_job_queue_name, jobDefinition=job_definition_name, ) @@ -509,15 +439,19 @@ def launch( try: response = batch_client.describe_jobs(jobs=[job_id]) status = response['jobs'][0]['status'] - + status_reason = response['jobs'][0].get('statusReason', 'Unknown') + if status == 'SUCCEEDED': logger.info(f"Job completed successfully: {job_id}") break elif status == "FAILED": - status_reason = response['jobs'][0].get('statusReason', 'Unknown') raise RuntimeError(f'Job {job_id} failed: {status_reason}') else: - logger.info(f"Job {job_id} neither failed nor succeeded: {status}. Waiting another 10 seconds.") + logger.info( + f"Job {job_id} neither failed nor succeeded. Status: " + f"{status}. Status reason: {status_reason}. Waiting " + "another 10 seconds." + ) time.sleep(10) except ClientError as e: logger.error(f"Failed to describe job {job_id}: {e}") From 96651075462121adce68c9f0e6fa0251cbdf7c85 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Thu, 25 Sep 2025 22:24:26 +0100 Subject: [PATCH 28/36] fixed fargate networking bug. the container spec model didnt have a networking config attibute and was silently ignoring the kwargs passed in the FARGATE path --- .../flavors/aws_batch_step_operator_flavor.py | 25 +++--- .../step_operators/aws_batch_step_operator.py | 63 ++++++++------ .../test_aws_batch_step_operator.py | 85 ++++++++----------- 3 files changed, 83 insertions(+), 90 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 03714552ca2..3ed0efa5c01 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union +from typing import TYPE_CHECKING, Dict, Optional, Type, Literal from pydantic import Field, PositiveInt, field_validator from zenml.utils.secret_utils import SecretField @@ -44,26 +44,25 @@ class AWSBatchStepOperatorSettings(BaseSettings): default="", description="The AWS Batch job queue to submit the step AWS Batch job" " to. If not provided, falls back to the default job queue name " - "specified at stack registration time." + "specified at stack registration time. Must be compatible with" + "`backend`." ) - platform_capability: str = Field( - default="FARGATE", + backend: Literal['EC2','FARGATE'] = Field( + default="EC2", description="The AWS Batch platform capability for the step AWS Batch " - "job to be orchestrated with. Defaults to 'FARGATE'." + "job to be orchestrated with. Must be compatible with `job_queue_name`." + "Defaults to 'EC2'." + ) + assign_public_ip: Literal['ENABLED','DISABLED'] = Field( + default="ENABLED", + description="Sets the network configuration's assignPublicIp field." + "Only relevant for FARGATE backend." ) timeout_seconds: PositiveInt = Field( default=3600, description="The number of seconds before AWS Batch times out the job." ) - @field_validator("platform_capability") - def validate_platform_capability(cls, value): - if value not in ["FARGATE","EC2"]: - raise ValueError(f"Invalid platform capability {value}. 
Must be " - "either 'FARGATE' or 'EC2'") - - return value - class AWSBatchStepOperatorConfig( diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 9dc046efb59..4661e9f78f2 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,8 +25,7 @@ Literal, cast, ) -from pydantic import BaseModel, PositiveInt, Field, field_validator -from pydantic_settings import BaseSettings as PydanticBaseSettings +from pydantic import BaseModel, PositiveInt, field_validator from boto3 import Session from zenml.client import Client @@ -73,6 +72,25 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel): resourceRequirements: List[ResourceRequirement] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY' secrets: List[Dict[str,str]] = [] # keys: 'name','value' + +class AWSBatchJobDefinitionEC2ContainerProperties(AWSBatchJobDefinitionContainerProperties): + logConfiguration: dict[Literal["logDriver"],Literal["awsfirelens", "awslogs", "fluentd", "gelf", "json-file", "journald", "logentries", "syslog", "splunk"]] = {"logDriver":"awslogs"} + +class AWSBatchJobDefinitionFargateContainerProperties(AWSBatchJobDefinitionContainerProperties): + logConfiguration: dict[Literal["logDriver"],Literal["awslogs","splunk"]] = {"logDriver":"awslogs"} + networkConfiguration: dict[Literal['assignPublicIp'],Literal['ENABLED','DISABLED']] = {"assignPublicIp": "ENABLED"} + + @field_validator("resourceRequirements") + def check_resource_requirements(cls,value: List[ResourceRequirement]): + for resource_requirement in value: + if resource_requirement.type == "GPU": + raise ValueError( + f"Invalid fargate resource requirement: GPU.Use EC2 " + "platform capability if you need custom devices." 
+ ) + + return value + class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" attempts: PositiveInt = 2 @@ -93,7 +111,6 @@ class AWSBatchJobDefinition(BaseModel): jobDefinitionName: str type: str = 'container' - containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None parameters: Dict[str,str] = {} # schedulingPriority: int = 0 # ignored in FIFO queues retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy() @@ -102,26 +119,13 @@ class AWSBatchJobDefinition(BaseModel): tags: Dict[str,str] = {} platformCapabilities: List[Literal["EC2","FARGATE"]] +class AWSBatchJobEC2Definition(AWSBatchJobDefinition): + containerProperties: AWSBatchJobDefinitionEC2ContainerProperties + platformCapabilities: list[Literal["EC2"]] = ["EC2"] -class AWSBatchJobDefinitionOnEC2(AWSBatchJobDefinition): - platformCapabilities: str = "EC2" - logConfiguration: dict[Literal["logDriver"],Literal["awsfirelens", "awslogs", "fluentd", "gelf", "json-file", "journald", "logentries", "syslog", "splunk"]] = {"logDriver":"awslogs"} - -class AWSBatchJobDefinitionOnFargate(AWSBatchJobDefinition): - platformCapabilities: str = "FARGATE" - logConfiguration: dict[Literal["logDriver"],Literal["awslogs","splunk"]] = {"logDriver":"awslogs"} - networkConfiguration: dict[Literal["assignPublicIp"],Literal["ENABLED","DISABLED"]] = {"assignPublicIp":"DISABLED"} - - @field_validator("containerProperties") - def check_resource_requirements(cls,value: AWSBatchJobDefinitionContainerProperties): - for resource_requirement in value.resourceRequirements: - if resource_requirement.type == "GPU": - raise ValueError( - f"Invalid fargate resource requirement: GPU.Use EC2 " - "platform capability if you need custom devices." 
- ) - - return value +class AWSBatchJobFargateDefinition(AWSBatchJobDefinition): + containerProperties: AWSBatchJobDefinitionFargateContainerProperties + platformCapabilities: list[Literal["FARGATE"]] = ["FARGATE"] class AWSBatchStepOperator(BaseStepOperator): """Step operator to run a step on AWS Batch. @@ -349,22 +353,27 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[ job_name = self.generate_unique_batch_job_name(info) - if step_settings.platform_capability == "EC2": - AWSBatchJobDefinitionClass = AWSBatchJobDefinitionOnEC2 - else: - AWSBatchJobDefinitionClass = AWSBatchJobDefinitionOnFargate + if step_settings.backend == "EC2": + AWSBatchJobDefinitionClass = AWSBatchJobEC2Definition + AWSBatchContainerProperties = AWSBatchJobDefinitionEC2ContainerProperties + container_kwargs = {} + elif step_settings.backend == 'FARGATE': + AWSBatchJobDefinitionClass = AWSBatchJobFargateDefinition + AWSBatchContainerProperties = AWSBatchJobDefinitionFargateContainerProperties + container_kwargs = {'networkConfiguration': {"assignPublicIp":step_settings.assign_public_ip}} return AWSBatchJobDefinitionClass( jobDefinitionName=job_name, timeout={'attemptDurationSeconds':step_settings.timeout_seconds}, type="container", - containerProperties=AWSBatchJobDefinitionContainerProperties( + containerProperties=AWSBatchContainerProperties( executionRoleArn=self.config.execution_role, jobRoleArn=self.config.job_role, image=image_name, command=entrypoint_command, environment=self.map_environment(environment), resourceRequirements=self.map_resource_settings(resource_settings), + **container_kwargs ) ) diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index 3aaa7bf9a43..d8654635a0b 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ 
b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -16,21 +16,13 @@ import pytest from zenml.config.resource_settings import ResourceSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_context - -def test_aws_batch_context(monkeypatch): - """Tests the AWSBatchContext class.""" - - monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_INDEX',0) - monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS','test-address') - monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1) - monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2) - - test_aws_batch_context = get_context() - assert test_aws_batch_context.main_node_index == 0 - assert test_aws_batch_context.main_node_address == 'test-address' - assert test_aws_batch_context.node_index == 1 - assert test_aws_batch_context.num_nodes == 2 +from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( + AWSBatchStepOperator, + ResourceRequirement, + AWSBatchJobDefinitionOnEC2, + AWSBatchJobDefinitionOnFargate, + AWSBatchJobDefinitionContainerProperties +) def test_aws_batch_step_operator_map_environment(): """Tests the AWSBatchStepOperator's map_environment class method.""" @@ -54,51 +46,44 @@ def test_aws_batch_step_operator_map_environment(): [ ( ResourceSettings(), - [] - ), - ( - ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"), [ - { - "value":"1", - "type":"VCPU" - }, - { - "value":"1", - "type":"GPU" - }, - { - "value": "10", - "type": "MEMORY" - } + ResourceRequirement(value="1",type="VCPU"), + ResourceRequirement(value="1024",type="MEMORY") ] ), ( - ResourceSettings(cpu_count=1,gpu_count=1), - [ - { - "value":"1", - "type":"VCPU" - }, - { - "value":"1", - "type":"GPU" - }, - ] - ), - ( - ResourceSettings(memory="1GiB"), + ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"), [ - { - "value": "1024", - "type": "MEMORY" - } + ResourceRequirement(value="1",type="VCPU"), + 
ResourceRequirement(value="10",type="MEMORY"), + ResourceRequirement(value="1",type="GPU"), ] ), - ] ) def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected): """Tests the AWSBatchStepOperator's map_resource_settings class method.""" assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected + +def test_aws_batch_job_definition_on_ec2(): + AWSBatchJobDefinitionOnEC2( + jobDefinitionName="test", + ) + +def test_aws_batch_job_definition_on_fargate(): + AWSBatchJobDefinitionOnFargate( + jobDefinitionName="test", + ) + +def test_aws_batch_job_definition_on_fargate_raise_gpu(): + with pytest.raises(ValueError, match="Invalid fargate resource requirement: GPU.Use EC2 platform capability if you need custom devices."): + AWSBatchJobDefinitionOnFargate( + containerProperties=AWSBatchJobDefinitionContainerProperties( + image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + executionRoleArn="test-execution-role-arn", + resourceRequirements=[ResourceRequirement(value="1",type="GPU")] + ) + ) \ No newline at end of file From 4e171c1b1e72f0a4bf3d41ae1e5214adc5338f26 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Thu, 25 Sep 2025 22:26:01 +0100 Subject: [PATCH 29/36] default backend is fargate bc its faster and easier to set up the infra --- .../aws/flavors/aws_batch_step_operator_flavor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index 3ed0efa5c01..d2300e45860 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -48,10 +48,10 @@ class AWSBatchStepOperatorSettings(BaseSettings): "`backend`." 
) backend: Literal['EC2','FARGATE'] = Field( - default="EC2", + default="FARGATE", description="The AWS Batch platform capability for the step AWS Batch " "job to be orchestrated with. Must be compatible with `job_queue_name`." - "Defaults to 'EC2'." + "Defaults to 'FARGATE'." ) assign_public_ip: Literal['ENABLED','DISABLED'] = Field( default="ENABLED", From 02f9281915af381f7b3ad575f12c3313cdb75275 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Thu, 25 Sep 2025 22:32:58 +0100 Subject: [PATCH 30/36] fixed integration tests --- .../test_aws_batch_step_operator.py | 36 +++++++++++++----- .../test_aws_batch_step_operator_flavor.py | 37 ++----------------- 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index d8654635a0b..add3fb6689d 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -19,9 +19,10 @@ from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( AWSBatchStepOperator, ResourceRequirement, - AWSBatchJobDefinitionOnEC2, - AWSBatchJobDefinitionOnFargate, - AWSBatchJobDefinitionContainerProperties + AWSBatchJobEC2Definition, + AWSBatchJobFargateDefinition, + AWSBatchJobDefinitionEC2ContainerProperties, + AWSBatchJobDefinitionFargateContainerProperties ) def test_aws_batch_step_operator_map_environment(): @@ -66,20 +67,35 @@ def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,ex assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected -def test_aws_batch_job_definition_on_ec2(): - AWSBatchJobDefinitionOnEC2( +def test_aws_batch_job_ec2_definition(): + AWSBatchJobEC2Definition( jobDefinitionName="test", + containerProperties=AWSBatchJobDefinitionEC2ContainerProperties( + 
image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + executionRoleArn="test-execution-role-arn", + resourceRequirements=[ResourceRequirement(value="1",type="GPU")] + ) ) -def test_aws_batch_job_definition_on_fargate(): - AWSBatchJobDefinitionOnFargate( +def test_aws_batch_job_fargate_definition(): + AWSBatchJobFargateDefinition( jobDefinitionName="test", + containerProperties=AWSBatchJobDefinitionFargateContainerProperties( + image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + executionRoleArn="test-execution-role-arn", + resourceRequirements=[ResourceRequirement(value="2",type="VCPU")] + ) ) -def test_aws_batch_job_definition_on_fargate_raise_gpu(): +def test_aws_batch_job_fargate_definition_raise_gpu(): with pytest.raises(ValueError, match="Invalid fargate resource requirement: GPU.Use EC2 platform capability if you need custom devices."): - AWSBatchJobDefinitionOnFargate( - containerProperties=AWSBatchJobDefinitionContainerProperties( + AWSBatchJobFargateDefinition( + jobDefinitionName="test", + containerProperties=AWSBatchJobDefinitionFargateContainerProperties( image="test-image", command=["test","command"], jobRoleArn="test-job-role-arn", diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py index b658c47b5ae..6554d73b57e 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator_flavor.py @@ -1,39 +1,8 @@ from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings -from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchJobDefinition -from zenml import pipeline, step -from zenml.config import DockerSettings -from zenml.client import Client - -# Get the ZenML 
client -client = Client() def test_aws_batch_step_operator_settings(): AWSBatchStepOperatorSettings( - instance_type="g4dn.xlarge", + job_queue_name='test-job-queue', environment={"key_1":"value_1","key_2":"value_2"}, - timeout_seconds=60 - ) - -def test_aws_batch_step_job_definition(): - AWSBatchJobDefinition( - jobDefinitionName="test-job-name", - ) - -# def test_aws_batch_step_operator_step(): - -# @step( -# step_operator="aws_batch", -# settings={ -# 'docker':DockerSettings(parent_image='test-image'), -# 'step_operator':AWSBatchStepOperatorSettings(instance_type='test-instance') -# } -# ) -# def test_step(name: str) -> str: - -# return f'Hello {name}! I am running on AWS Batch!' - -# @pipeline -# def test_pipeline(name: str): -# test_step(name) - -# test_pipeline("Sebastian",stack=Client().get_stack("aws_batch_stack")) \ No newline at end of file + timeout_seconds=3600 + ) \ No newline at end of file From 8fc2959d05ffb87af2de1ea8916289833d30ac80 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Wed, 1 Oct 2025 22:20:47 +0100 Subject: [PATCH 31/36] addressed all comments except logging --- .../flavors/aws_batch_step_operator_flavor.py | 2 +- .../aws/step_operators/__init__.py | 2 +- .../step_operators/aws_batch_step_operator.py | 34 +++++++------------ ...s_batch_step_operator_entrypoint_config.py | 22 ------------ 4 files changed, 14 insertions(+), 46 deletions(-) delete mode 100644 src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py diff --git a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py index d2300e45860..a042dc53b95 100644 --- a/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing # permissions and limitations under the License. -"""Amazon SageMaker step operator flavor.""" +"""AWS Batch Step operator flavor.""" from typing import TYPE_CHECKING, Dict, Optional, Type, Literal diff --git a/src/zenml/integrations/aws/step_operators/__init__.py b/src/zenml/integrations/aws/step_operators/__init__.py index d04fccea431..f766a309217 100644 --- a/src/zenml/integrations/aws/step_operators/__init__.py +++ b/src/zenml/integrations/aws/step_operators/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Initialization of the Sagemaker Step Operator.""" +"""Initialization of the AWS Batch Step Operator.""" from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401 SagemakerStepOperator, diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 4661e9f78f2..2bbd984e091 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Implementation of the Sagemaker Step Operator.""" +"""Implementation of the AWS Batch Step Operator.""" import time import math @@ -197,17 +197,6 @@ def _get_aws_session(self) -> Session: ) return boto_session - @property - def entrypoint_config_class( - self, - ) -> Type[StepOperatorEntrypointConfiguration]: - """Returns the entrypoint configuration class for this step operator. - - Returns: - The entrypoint configuration class for this step operator. 
- """ - return AWSBatchEntrypointConfiguration - @property def validator(self) -> Optional[StackValidator]: """Validates the stack. @@ -258,11 +247,11 @@ def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]: definition spec. Args: - environment (Dict[str,str]): The step's environment variable + environment: The step's environment variable specification Returns: - List[Dict[str,str]]: The mapped environment variable specification + The mapped environment variable specification """ return [ @@ -275,10 +264,10 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List["Resour in the AWS Batch Job definition spec. Args: - resource_settings (ResourceSettings): The step's resource settings. + resource_settings: The step's resource settings. Returns: - List[Dict[str,str]]: The mapped resource settings. + The mapped resource settings. """ mapped_resource_settings = [] @@ -327,16 +316,17 @@ def generate_unique_batch_job_name(info: "StepRunInfo") -> str: """Utility to generate a unique AWS Batch job name. Args: - info (StepRunInfo): The step run information. + info: The step run information. Returns: - str: A unique name for the step's AWS Batch job definition + A unique name for the step's AWS Batch job definition """ - # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin. - step_name = Client().get_run_step(info.step_run_id).name - job_name = f"{info.pipeline.name}-{step_name}"[:55] - suffix = random_str(4) + # Batch allows 128 alphanumeric characters at maximum for job name - ZenML uses 60 for safety margin. 
+ # AWS Batch job description names are more permissive than ZenML pipeline and step naming rules, + # so no sanitation needed besides trimming + job_name = f"{info.pipeline.name}-{info.pipeline_step_name}"[:120] + suffix = random_str(6) return f"{job_name}-{suffix}" def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition: diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py deleted file mode 100644 index cd57398a38f..00000000000 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) ZenML GmbH 2023. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. 
-"""Entrypoint configuration for ZenML Batch step operator.""" - -from zenml.step_operators.step_operator_entrypoint_configuration import ( - StepOperatorEntrypointConfiguration, -) - - -class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration): - """Entrypoint configuration for ZenML Batch step operator.""" \ No newline at end of file From 6232f4fea1078d686eb415ac853de6ac70c37f6f Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Wed, 1 Oct 2025 22:25:25 +0100 Subject: [PATCH 32/36] buffer of 5 chars --- .../integrations/aws/step_operators/aws_batch_step_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 2bbd984e091..c86909056f6 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -325,7 +325,7 @@ def generate_unique_batch_job_name(info: "StepRunInfo") -> str: # Batch allows 128 alphanumeric characters at maximum for job name - ZenML uses 60 for safety margin. 
# AWS Batch job description names are more permissive than ZenML pipeline and step naming rules, # so no sanitation needed besides trimming - job_name = f"{info.pipeline.name}-{info.pipeline_step_name}"[:120] + job_name = f"{info.pipeline.name}-{info.pipeline_step_name}"[:115] suffix = random_str(6) return f"{job_name}-{suffix}" From 705c2a96a5a7c6f9618bbe4ff9ae0d131e32be24 Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sat, 4 Oct 2025 15:59:27 +0100 Subject: [PATCH 33/36] added validation of pipeline and step name before assembling full job description name@ --- .../step_operators/aws_batch_step_operator.py | 32 +++++++++++++------ .../test_aws_batch_step_operator.py | 11 +++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index c86909056f6..75b08875967 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -25,25 +25,19 @@ Literal, cast, ) +import re from pydantic import BaseModel, PositiveInt, field_validator from boto3 import Session -from zenml.client import Client from zenml.config.build_configuration import BuildConfiguration from zenml.enums import StackComponentType from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import ( AWSBatchStepOperatorConfig, AWSBatchStepOperatorSettings, ) -from zenml.integrations.aws.step_operators.aws_batch_step_operator_entrypoint_config import ( - AWSBatchEntrypointConfiguration, -) from zenml.logger import get_logger from zenml.stack import Stack, StackValidator from zenml.step_operators import BaseStepOperator -from zenml.step_operators.step_operator_entrypoint_configuration import ( - StepOperatorEntrypointConfiguration, -) from zenml.utils.string_utils import random_str from botocore.exceptions import ClientError @@ -310,9 +304,14 @@ def 
map_resource_settings(resource_settings: "ResourceSettings") -> List["Resour ) return mapped_resource_settings - + @staticmethod - def generate_unique_batch_job_name(info: "StepRunInfo") -> str: + def is_name_valid(name: str) -> bool: + pattern = re.compile(r"[A-Za-z0-9_-]+$") + + return bool(pattern.fullmatch(name)) + + def generate_unique_batch_job_name(self, info: "StepRunInfo") -> str: """Utility to generate a unique AWS Batch job name. Args: @@ -325,6 +324,21 @@ def generate_unique_batch_job_name(info: "StepRunInfo") -> str: # Batch allows 128 alphanumeric characters at maximum for job name - ZenML uses 60 for safety margin. # AWS Batch job description names are more permissive than ZenML pipeline and step naming rules, # so no sanitation needed besides trimming + + assert ( + self.is_name_valid(info.pipeline.name), + f"Invalid pipeline name {info.pipeline.name}: A pipeline containing" + "an AWS Batch step must contain only upper and lower case " + "characters [a-zA-Z], digits [0-9], hyphens and underscores [-_]" + ) + + assert ( + self.is_name_valid(info.pipeline_step_name), + f"Invalid step name {info.pipeline_step_name}: An AWS Batch step " + "name must contain only upper and lower case characters [a-zA-Z], " + "digits [0-9], hyphens and underscores [-_]" + ) + job_name = f"{info.pipeline.name}-{info.pipeline_step_name}"[:115] suffix = random_str(6) return f"{job_name}-{suffix}" diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index add3fb6689d..a3332c9ee9b 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -67,6 +67,17 @@ def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,ex assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected 
+@pytest.mark.parametrize(
+    "test_name,expected",
+    [
+        ('valid-name-123abcABC_', True),
+        ('invalid$name!',False)
+    ]
+)
+def test_aws_batch_step_operator_is_name_valid(test_name, expected):
+
+    assert AWSBatchStepOperator.is_name_valid(test_name) == expected
+
 def test_aws_batch_job_ec2_definition():
     AWSBatchJobEC2Definition(
         jobDefinitionName="test",

From 971cd68e17c686abfbb4fc5dc30b466646ac6441 Mon Sep 17 00:00:00 2001
From: SebastianScherer88
Date: Sun, 5 Oct 2025 20:15:44 +0100
Subject: [PATCH 34/36] implemented name sanitization as suggested instead of
 raising exceptions for invalid characters

---
 .../step_operators/aws_batch_step_operator.py | 34 ++++++++-----------
 .../test_aws_batch_step_operator.py           |  8 ++---
 2 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py
index 75b08875967..a2239f21649 100644
--- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py
+++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py
@@ -40,6 +40,7 @@
 from zenml.step_operators import BaseStepOperator
 from zenml.utils.string_utils import random_str
 from botocore.exceptions import ClientError
+from string import ascii_letters, digits
 
 if TYPE_CHECKING:
     from zenml.config.base_settings import BaseSettings
@@ -306,10 +307,13 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List["Resour
         return mapped_resource_settings
 
     @staticmethod
-    def is_name_valid(name: str) -> bool:
-        pattern = re.compile(r"[A-Za-z0-9_-]+$")
+    def sanitize_name(name: str) -> str:
+        valid_characters = ascii_letters + digits + '-_'
+        sanitized_name = ''
+        for char in name:
+            sanitized_name += char if char in valid_characters else '-'
 
-        return bool(pattern.fullmatch(name))
+        return sanitized_name
 
     def generate_unique_batch_job_name(self, info: "StepRunInfo") -> str:
         """Utility to generate a unique AWS Batch job
name.
@@ -321,25 +325,15 @@ def generate_unique_batch_job_name(self, info: "StepRunInfo") -> str:
             A unique name for the step's AWS Batch job definition
         """
-        # Batch allows 128 alphanumeric characters at maximum for job name - ZenML uses 60 for safety margin.
-        # AWS Batch job description names are more permissive than ZenML pipeline and step naming rules,
-        # so no sanitation needed besides trimming
+        # Batch allows 128 alphanumeric characters at maximum for job name.
+        # We sanitize the pipeline and step names before concatenating,
+        # capping at 115 chars and finally suffixing with a 6 character random
+        # string
 
-        assert (
-            self.is_name_valid(info.pipeline.name),
-            f"Invalid pipeline name {info.pipeline.name}: A pipeline containing"
-            "an AWS Batch step must contain only upper and lower case "
-            "characters [a-zA-Z], digits [0-9], hyphens and underscores [-_]"
-        )
-
-        assert (
-            self.is_name_valid(info.pipeline_step_name),
-            f"Invalid step name {info.pipeline_step_name}: An AWS Batch step "
-            "name must contain only upper and lower case characters [a-zA-Z], "
-            "digits [0-9], hyphens and underscores [-_]"
-        )
+        sanitized_pipeline_name = self.sanitize_name(info.pipeline.name)
+        sanitized_step_name = self.sanitize_name(info.pipeline_step_name)
 
-        job_name = f"{info.pipeline.name}-{info.pipeline_step_name}"[:115]
+        job_name = f"{sanitized_pipeline_name}-{sanitized_step_name}"[:115]
         suffix = random_str(6)
         return f"{job_name}-{suffix}"
diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py
index a3332c9ee9b..b24dade6db1 100644
--- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py
+++ b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py
@@ -70,13 +70,13 @@ def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,ex
 @pytest.mark.parametrize(
"test_name,expected", [ - ('valid-name-123abcABC_', True), - ('invalid$name!',False) + ('valid-name-123abcABC_', 'valid-name-123abcABC_'), + ('this!is@not"a£valid$name%123','this-is-not-a-valid-name-123') ] ) -def test_aws_batch_step_operator_is_name_valid(test_name, expected): +def test_aws_batch_step_operator_sanitize_name(test_name, expected): - assert AWSBatchStepOperator.is_name_valid(test_name) == expected + assert AWSBatchStepOperator.sanitize_name(test_name) == expected def test_aws_batch_job_ec2_definition(): AWSBatchJobEC2Definition( From d2ace2405a0bb8f1d1b37df2a70d14415758772e Mon Sep 17 00:00:00 2001 From: SebastianScherer88 Date: Sun, 5 Oct 2025 22:05:33 +0100 Subject: [PATCH 35/36] added ec2 and fargate resource validation to schemas, simplified resource mapping method. updated unit test coverage --- .../step_operators/aws_batch_step_operator.py | 89 +++++++--- .../test_aws_batch_step_operator.py | 163 ++++++++++++++++-- 2 files changed, 211 insertions(+), 41 deletions(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index a2239f21649..16ff89526fb 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -53,6 +53,18 @@ BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator" _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT" + +VALID_FARGATE_VCPU = ('0.25', '0.5', '1', '2', '4', '8', '16') +VALID_FARGATE_MEMORY = { + '0.25': ('512', '1024', '2048'), + '0.5': ('1024', '2048', '3072', '4096'), + '1': ('2048', '3072', '4096', '5120', '6144', '7168', '8192'), + '2': tuple(str(m) for m in range(4096, 16385, 1024)), + '4': tuple(str(m) for m in range(8192, 30721, 1024)), + '8': tuple(str(m) for m in range(16384, 61441, 4096)), + '16': tuple(str(m) for m in range(32768, 122880, 8192)) +} + class ResourceRequirement(BaseModel): type: 
Literal["MEMORY","VCPU","GPU"]
     value: str
 
@@ -67,24 +79,67 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel):
     resourceRequirements: List[ResourceRequirement] = [] # keys: 'value','type', with type one of 'GPU','VCPU','MEMORY'
     secrets: List[Dict[str,str]] = [] # keys: 'name','value'
-    
+
 class AWSBatchJobDefinitionEC2ContainerProperties(AWSBatchJobDefinitionContainerProperties):
     logConfiguration: dict[Literal["logDriver"],Literal["awsfirelens", "awslogs", "fluentd", "gelf", "json-file", "journald", "logentries", "syslog", "splunk"]] = {"logDriver":"awslogs"}
 
+    @field_validator("resourceRequirements")
+    def check_resource_requirements(cls,resource_requirements: List[ResourceRequirement]) -> List[ResourceRequirement]:
+
+        gpu_requirement = [req for req in resource_requirements if req.type == "GPU"]
+        cpu_requirement = [req for req in resource_requirements if req.type == "VCPU"][0]
+        memory_requirement = [req for req in resource_requirements if req.type == "MEMORY"][0]
+
+        cpu_float = float(cpu_requirement.value)
+        cpu_rounded_int = math.ceil(cpu_float)
+
+        if cpu_float != cpu_rounded_int:
+            logger.info(
+                f"Rounded fractional EC2 resource VCPU value from {cpu_float} to {cpu_rounded_int} "
+                "since AWS Batch on EC2 requires whole integer VCPU count value."
+ ) + resource_requirements = [ + ResourceRequirement( + type="VCPU", + value=str(cpu_rounded_int) + ), + memory_requirement + ] + resource_requirements.extend(gpu_requirement) + + return resource_requirements + class AWSBatchJobDefinitionFargateContainerProperties(AWSBatchJobDefinitionContainerProperties): logConfiguration: dict[Literal["logDriver"],Literal["awslogs","splunk"]] = {"logDriver":"awslogs"} networkConfiguration: dict[Literal['assignPublicIp'],Literal['ENABLED','DISABLED']] = {"assignPublicIp": "ENABLED"} @field_validator("resourceRequirements") - def check_resource_requirements(cls,value: List[ResourceRequirement]): - for resource_requirement in value: - if resource_requirement.type == "GPU": - raise ValueError( - f"Invalid fargate resource requirement: GPU.Use EC2 " - "platform capability if you need custom devices." - ) + def check_resource_requirements(cls,resource_requirements: List[ResourceRequirement]) -> List[ResourceRequirement]: + + gpu_requirement = [req for req in resource_requirements if req.type == "GPU"] + + if gpu_requirement: + raise ValueError( + f"Invalid fargate resource requirement: GPU. Use EC2 " + "platform capability if you need custom devices." + ) + + cpu_requirement = [req for req in resource_requirements if req.type == "VCPU"][0] + memory_requirement = [req for req in resource_requirements if req.type == "MEMORY"][0] + + if cpu_requirement.value not in VALID_FARGATE_VCPU: + raise ValueError( + f"Invalid fargate resource requirement VCPU value {cpu_requirement.value}." + f"Must be one of {VALID_FARGATE_VCPU}" + ) + + if memory_requirement.value not in VALID_FARGATE_MEMORY[cpu_requirement.value]: + raise ValueError( + f"Invalid fargate resource requirement MEMORY value {memory_requirement.value}." 
+ f"For VCPU={cpu_requirement.value}, MEMORY must be one of {VALID_FARGATE_MEMORY[cpu_requirement.value]}" + ) - return value + return resource_requirements class AWSBatchJobDefinitionRetryStrategy(BaseModel): """An AWS Batch job subconfiguration model for retry specifications.""" @@ -268,25 +323,15 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List["Resour # handle cpu requirements if resource_settings.cpu_count is not None: - - cpu_count_int = math.ceil(resource_settings.cpu_count) - - if cpu_count_int != resource_settings.cpu_count: - logger.info( - "AWS Batch only accepts int type cpu resource " - f"requirements. Converted {resource_settings.cpu_count}" - f" to {cpu_count_int}." - ) - cpu_requirement = ResourceRequirement(value=str(cpu_count_int),type="VCPU") + cpu_requirement = ResourceRequirement(value=str(resource_settings.cpu_count),type="VCPU") else: cpu_requirement = ResourceRequirement(value="1",type="VCPU") mapped_resource_settings.append(cpu_requirement) # handle memory requirements - memory = resource_settings.get_memory(unit="MiB") + memory = resource_settings.get_memory(unit="MiB") if memory: - memory_requirement = ResourceRequirement( value=str(int(memory)), type="MEMORY" @@ -296,7 +341,7 @@ def map_resource_settings(resource_settings: "ResourceSettings") -> List["Resour mapped_resource_settings.append(memory_requirement) # handle gpu requirements - if resource_settings.gpu_count is not None: + if resource_settings.gpu_count is not None and resource_settings.gpu_count != 0: mapped_resource_settings.append( ResourceRequirement( value=str(resource_settings.gpu_count), diff --git a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py index b24dade6db1..2f0bac6e3e4 100644 --- a/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py +++ 
b/tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py @@ -14,6 +14,7 @@ import pytest +from pydantic import ValidationError from zenml.config.resource_settings import ResourceSettings from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( @@ -22,7 +23,9 @@ AWSBatchJobEC2Definition, AWSBatchJobFargateDefinition, AWSBatchJobDefinitionEC2ContainerProperties, - AWSBatchJobDefinitionFargateContainerProperties + AWSBatchJobDefinitionFargateContainerProperties, + VALID_FARGATE_MEMORY, + VALID_FARGATE_VCPU ) def test_aws_batch_step_operator_map_environment(): @@ -53,9 +56,9 @@ def test_aws_batch_step_operator_map_environment(): ] ), ( - ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"), + ResourceSettings(cpu_count=0.25,gpu_count=1,memory="10MiB"), [ - ResourceRequirement(value="1",type="VCPU"), + ResourceRequirement(value="0.25",type="VCPU"), ResourceRequirement(value="10",type="MEMORY"), ResourceRequirement(value="1",type="GPU"), ] @@ -78,6 +81,134 @@ def test_aws_batch_step_operator_sanitize_name(test_name, expected): assert AWSBatchStepOperator.sanitize_name(test_name) == expected +@pytest.mark.parametrize( + "test_requirements,expected", + [ + ( + [ + ResourceRequirement(value="0.4",type="VCPU"), + ResourceRequirement(value="100",type="MEMORY"), + ResourceRequirement(value="1",type="GPU") + ],[ + ResourceRequirement(value="1",type="VCPU"), + ResourceRequirement(value="100",type="MEMORY"), + ResourceRequirement(value="1",type="GPU") + ] + ), + ( + [ + ResourceRequirement(value="1.1",type="VCPU"), + ResourceRequirement(value="100",type="MEMORY"), + ],[ + ResourceRequirement(value="2",type="VCPU"), + ResourceRequirement(value="100",type="MEMORY"), + ] + ), + ] +) +def test_aws_batch_job_definition_ec2_container_properties_resource_validation(test_requirements, expected): + actual = AWSBatchJobDefinitionEC2ContainerProperties( + image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + 
executionRoleArn="test-execution-role-arn", + resourceRequirements=test_requirements + ) + + assert actual.resourceRequirements == expected + +@pytest.mark.parametrize( + "test_vcpu_memory_indices", + [ + (i,j) for i in range(len(VALID_FARGATE_VCPU)) for j in range(len(VALID_FARGATE_MEMORY[VALID_FARGATE_VCPU[i]])) + ] + +) +def test_aws_batch_job_definition_fargate_container_properties(test_vcpu_memory_indices): + + vcpu_index, memory_index = test_vcpu_memory_indices + test_vcpu_value = VALID_FARGATE_VCPU[vcpu_index] + test_memory_value = VALID_FARGATE_MEMORY[test_vcpu_value][memory_index] + + test_valid_requirements = [ + ResourceRequirement( + type="VCPU", + value=test_vcpu_value + ), + ResourceRequirement( + type="MEMORY", + value=test_memory_value + ) + ] + + AWSBatchJobDefinitionFargateContainerProperties( + image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + executionRoleArn="test-execution-role-arn", + resourceRequirements=test_valid_requirements + ) + +@pytest.mark.parametrize( + "test_invalid_requirements,expected_message", + [ + ( + [ + ResourceRequirement( + type="VCPU", + value="invalid-value" + ), + ResourceRequirement( + type="MEMORY", + value="irrelevant-value" + ) + ], + "Invalid fargate resource requirement VCPU value*" + ), + ( + [ + ResourceRequirement( + type="VCPU", + value="16" # valid + ), + ResourceRequirement( + type="MEMORY", + value="invalid-value" + ) + ], + "Invalid fargate resource requirement MEMORY value*" + ), + ( + [ + ResourceRequirement( + type="VCPU", + value="irrelevant-value" + ), + ResourceRequirement( + type="MEMORY", + value="irrelevant-value" + ), + ResourceRequirement( + type="GPU", + value="1" # invalid + ) + ], + "Invalid fargate resource requirement: GPU. 
Use EC2*" + ) + ] +) +def test_aws_batch_job_definition_fargate_container_properties_raise_invalid_requirements(test_invalid_requirements,expected_message): + + with pytest.raises(ValidationError,match=expected_message): + + AWSBatchJobDefinitionFargateContainerProperties( + image="test-image", + command=["test","command"], + jobRoleArn="test-job-role-arn", + executionRoleArn="test-execution-role-arn", + resourceRequirements=test_invalid_requirements + ) + def test_aws_batch_job_ec2_definition(): AWSBatchJobEC2Definition( jobDefinitionName="test", @@ -86,7 +217,11 @@ def test_aws_batch_job_ec2_definition(): command=["test","command"], jobRoleArn="test-job-role-arn", executionRoleArn="test-execution-role-arn", - resourceRequirements=[ResourceRequirement(value="1",type="GPU")] + resourceRequirements=[ + ResourceRequirement(value="1",type="GPU"), + ResourceRequirement(value="1",type="VCPU"), + ResourceRequirement(value="1024",type="MEMORY") + ] ) ) @@ -98,19 +233,9 @@ def test_aws_batch_job_fargate_definition(): command=["test","command"], jobRoleArn="test-job-role-arn", executionRoleArn="test-execution-role-arn", - resourceRequirements=[ResourceRequirement(value="2",type="VCPU")] + resourceRequirements=[ + ResourceRequirement(value="0.5",type="VCPU"), + ResourceRequirement(value="3072",type="MEMORY") + ] ) - ) - -def test_aws_batch_job_fargate_definition_raise_gpu(): - with pytest.raises(ValueError, match="Invalid fargate resource requirement: GPU.Use EC2 platform capability if you need custom devices."): - AWSBatchJobFargateDefinition( - jobDefinitionName="test", - containerProperties=AWSBatchJobDefinitionFargateContainerProperties( - image="test-image", - command=["test","command"], - jobRoleArn="test-job-role-arn", - executionRoleArn="test-execution-role-arn", - resourceRequirements=[ResourceRequirement(value="1",type="GPU")] - ) - ) \ No newline at end of file + ) \ No newline at end of file From d3be040d0d52297e392ef4921609d64404eeb9a0 Mon Sep 17 00:00:00 2001 
From: SebastianScherer88 Date: Mon, 6 Oct 2025 21:36:34 +0100 Subject: [PATCH 36/36] fixed bug in fargate resource memory validation range --- .../integrations/aws/step_operators/aws_batch_step_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py index 16ff89526fb..0eac0fe8b06 100644 --- a/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py @@ -62,7 +62,7 @@ '2': tuple(str(m) for m in range(4096, 16385, 1024)), '4': tuple(str(m) for m in range(8192, 30721, 1024)), '8': tuple(str(m) for m in range(16384, 61441, 4096)), - '16': tuple(str(m) for m in range(32768, 122880, 8192)) + '16': tuple(str(m) for m in range(32768, 122881, 8192)) } class ResourceRequirement(BaseModel):