Skip to content

Commit c9b5829

Browse files
Added support for the multinode AWS Batch job type
1 parent 0543331 commit c9b5829

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# permissions and limitations under the License.
1414
"""AWS Batch step operator flavor."""
1515

16-
from typing import TYPE_CHECKING, Dict, Optional, Type
16+
from typing import TYPE_CHECKING, Dict, Optional, Type, List, Union
1717

18-
from pydantic import Field
18+
from pydantic import Field, PositiveInt
1919

2020
from zenml.config.base_settings import BaseSettings
2121
from zenml.integrations.aws import (
@@ -35,7 +35,7 @@
3535
class AWSBatchStepOperatorSettings(BaseSettings):
3636
"""Settings for the AWS Batch step operator."""
3737

38-
instance_type: str = Field(
38+
instance_type: Union[str,List[str]] = Field(
3939
default='optimal',
4040
description="The instance type for AWS Batch to use for the step" \
4141
" execution. Example: 'm5.xlarge'",
@@ -45,11 +45,19 @@ class AWSBatchStepOperatorSettings(BaseSettings):
4545
description="Environment variables to pass to the container during " \
4646
"execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
4747
)
48-
timeout_seconds: int = Field(
48+
node_count: PositiveInt = Field(
49+
description="The number of AWS Batch nodes to run the step on. If > 1," \
50+
"an AWS Batch multinode job will be run, with the network connectivity" \
51+
"between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \
52+
"for details."
53+
)
54+
timeout_seconds: PositiveInt = Field(
4955
default=120,
5056
description="The number of seconds before AWS Batch times out the job."
5157
)
5258

59+
60+
5361
class AWSBatchStepOperatorConfig(
5462
BaseStepOperatorConfig, AWSBatchStepOperatorSettings
5563
):

src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
Literal,
2626
cast,
2727
)
28-
from pydantic import BaseModel
28+
from pydantic import BaseModel, PositiveInt
2929
import boto3
3030

3131
from zenml.client import Client
@@ -71,30 +71,22 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel):
7171
class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel):
7272
"""An AWS Batch job subconfiguration model for a node in a multinode job
7373
specifications.
74-
75-
Note: We include this class for completeness sake to make it easier to add
76-
multinode support later for now.
7774
"""
7875
targetNodes: str
7976
container: AWSBatchJobDefinitionContainerProperties
8077

8178
class AWSBatchJobDefinitionNodeProperties(BaseModel):
8279
"""An AWS Batch job subconfiguration model for multinode job specifications.
83-
84-
Note: We include this class for completeness sake to make it easier to add
85-
multinode support later for now, we'll set defaults to intuitively
86-
represent the only supported exeuction type ('container'); in reality AWS
87-
Batch will ignore this config.
8880
"""
89-
numNodes: int = 1
81+
numNodes: PositiveInt = 1
9082
mainNode: int = 0
9183
nodeRangeProperties: List[
9284
AWSBatchJobDefinitionNodePropertiesNodeRangeProperty
9385
] = []
9486

9587
class AWSBatchJobDefinitionRetryStrategy(BaseModel):
9688
"""An AWS Batch job subconfiguration model for retry specifications."""
97-
attempts: int = 2
89+
attempts: PositiveInt = 2
9890
evaluateOnExit: List[Dict[str,str]] = [
9991
{
10092
"onExitCode": "137", # out-of-memory killed
@@ -122,9 +114,8 @@ class AWSBatchJobDefinition(BaseModel):
122114
type: Literal['container','multinode'] = 'container' # 'multinode' is used when a step requests more than one node
123115
parameters: Dict[str,str] = {}
124116
schedulingPriority: int = 0 # ignored in FIFO queues
125-
containerProperties: AWSBatchJobDefinitionContainerProperties
126-
nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties(
127-
numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well
117+
containerProperties: Optional[AWSBatchJobDefinitionContainerProperties] = None
118+
nodeProperties: Optional[AWSBatchJobDefinitionNodeProperties] = None
128119
retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy()
129120
propagateTags: bool = False
130121
timeout: Dict[str,int] = {'attemptDurationSeconds':60} # key 'attemptDurationSeconds'
@@ -305,10 +296,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[
305296
step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info))
306297

307298
job_name = self.generate_unique_batch_job_name(info)
308-
309-
return AWSBatchJobDefinition(
310-
jobDefinitionName=job_name,
311-
containerProperties=AWSBatchJobDefinitionContainerProperties(
299+
container_properties = AWSBatchJobDefinitionContainerProperties(
312300
executionRoleArn=self.config.execution_role,
313301
jobRoleArn=self.config.job_role,
314302
image=image_name,
@@ -317,8 +305,33 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[
317305
instanceType=step_settings.instance_type,
318306
resourceRequirements=self.map_resource_settings(resource_settings),
319307
),
308+
309+
node_count = step_settings.node_count
310+
311+
if node_count == 1:
312+
kwargs = {
313+
'type':'container',
314+
'containerProperties':container_properties
315+
}
316+
else:
317+
kwargs = {
318+
'type':'multinode',
319+
'nodeProperties':AWSBatchJobDefinitionNodeProperties(
320+
numNodes=node_count,
321+
nodeRangeProperties=[
322+
AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(
323+
targetNodes=','.join([str(node_index) for node_index in range(node_count)]),
324+
container=container_properties
325+
)
326+
]
327+
)
328+
}
329+
330+
331+
return AWSBatchJobDefinition(
332+
jobDefinitionName=job_name,
320333
timeout={'attemptDurationSeconds':step_settings.timeout_seconds},
321-
# type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version
334+
**kwargs
322335
# parameters: Dict[str,str] = {}
323336
# schedulingPriority: int = 0 # ignored in FIFO queues
324337
# nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties(

0 commit comments

Comments
 (0)