2525 Literal ,
2626 cast ,
2727)
28- from pydantic import BaseModel
28+ from pydantic import BaseModel , PositiveInt
2929import boto3
3030
3131from zenml .client import Client
@@ -71,30 +71,22 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel):
7171class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty (BaseModel ):
7272 """An AWS Batch job subconfiguration model for a node in a multinode job
7373 specifications.
74-
75- Note: We include this class for completeness sake to make it easier to add
76- multinode support later for now.
7774 """
7875 targetNodes : str
7976 container : AWSBatchJobDefinitionContainerProperties
8077
8178class AWSBatchJobDefinitionNodeProperties (BaseModel ):
8279 """An AWS Batch job subconfiguration model for multinode job specifications.
83-
84- Note: We include this class for completeness sake to make it easier to add
85- multinode support later for now, we'll set defaults to intuitively
86- represent the only supported exeuction type ('container'); in reality AWS
87- Batch will ignore this config.
8880 """
89- numNodes : int = 1
81+ numNodes : PositiveInt = 1
9082 mainNode : int = 0
9183 nodeRangeProperties : List [
9284 AWSBatchJobDefinitionNodePropertiesNodeRangeProperty
9385 ] = []
9486
9587class AWSBatchJobDefinitionRetryStrategy (BaseModel ):
9688 """An AWS Batch job subconfiguration model for retry specifications."""
97- attempts : int = 2
89+ attempts : PositiveInt = 2
9890 evaluateOnExit : List [Dict [str ,str ]] = [
9991 {
10092 "onExitCode" : "137" , # out-of-memory killed
@@ -122,9 +114,8 @@ class AWSBatchJobDefinition(BaseModel):
122114 type : Literal ['container' ,'multinode' ] = 'container' # we dont support multinode type in this version
123115 parameters : Dict [str ,str ] = {}
124116 schedulingPriority : int = 0 # ignored in FIFO queues
125- containerProperties : AWSBatchJobDefinitionContainerProperties
126- nodeProperties : AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties (
127- numNodes = 1 ,mainNode = 0 ,nodeRangeProperties = []) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well
117+ containerProperties : Optional [AWSBatchJobDefinitionContainerProperties ] = None
118+ nodeProperties : Optional [AWSBatchJobDefinitionNodeProperties ] = None
128119 retryStrategy : AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy ()
129120 propagateTags : bool = False
130121 timeout : Dict [str ,int ] = {'attemptDurationSeconds' :60 } # key 'attemptDurationSeconds'
@@ -305,10 +296,7 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[
305296 step_settings = cast (AWSBatchStepOperatorSettings , self .get_settings (info ))
306297
307298 job_name = self .generate_unique_batch_job_name (info )
308-
309- return AWSBatchJobDefinition (
310- jobDefinitionName = job_name ,
311- containerProperties = AWSBatchJobDefinitionContainerProperties (
299+ container_properties = AWSBatchJobDefinitionContainerProperties (
312300 executionRoleArn = self .config .execution_role ,
313301 jobRoleArn = self .config .job_role ,
314302 image = image_name ,
@@ -317,8 +305,33 @@ def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[
317305 instanceType = step_settings .instance_type ,
318306 resourceRequirements = self .map_resource_settings (resource_settings ),
319307 ),
308+
309+ node_count = step_settings .node_count
310+
311+ if node_count == 1 :
312+ kwargs = {
313+ 'type' :'container' ,
314+ 'containerProperties' :container_properties
315+ }
316+ else :
317+ kwargs = {
318+ 'type' :'multinode' ,
319+ 'nodeProperties' :AWSBatchJobDefinitionNodeProperties (
320+ numNodes = node_count ,
321+ nodeRangeProperties = [
322+ AWSBatchJobDefinitionNodePropertiesNodeRangeProperty (
323+ targetNodes = ',' .join ([str (node_index ) for node_index in range (node_count )]),
324+ container = container_properties
325+ )
326+ ]
327+ )
328+ }
329+
330+
331+ return AWSBatchJobDefinition (
332+ jobDefinitionName = job_name ,
320333 timeout = {'attemptDurationSeconds' :step_settings .timeout_seconds },
321- # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version
334+ ** kwargs
322335 # parameters: Dict[str,str] = {}
323336 # schedulingPriority: int = 0 # ignored in FIFO queues
324337 # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties(
0 commit comments