5454
5555if TYPE_CHECKING :
5656 from zenml .config .base_settings import BaseSettings
57+ from zenml .config import ResourceSettings
5758 from zenml .config .step_run_info import StepRunInfo
5859 from zenml .models import PipelineDeploymentBase
5960
6364_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT"
6465
6566class AWSBatchJobDefinitionContainerProperties (BaseModel ):
67+ """An AWS Batch job subconfiguration model for a container specification."""
6668 image : str
6769 command : List [str ]
6870 jobRoleArn : str
@@ -73,20 +75,31 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel):
7375 secrets : List [Dict [str ,str ]] = [] # keys: 'name','value'
7476
class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel):
    """A single node-range entry of an AWS Batch multinode job specification.

    Note: Included for completeness' sake so that multinode support is
    easier to add later; it is unused by single-container jobs for now.
    """

    targetNodes: str
    container: AWSBatchJobDefinitionContainerProperties
7886
class AWSBatchJobDefinitionNodeProperties(BaseModel):
    """An AWS Batch job subconfiguration model for multinode job specifications.

    Note: Included for completeness' sake to make it easier to add multinode
    support later. For now, the defaults intuitively represent the only
    supported execution type ('container'); in reality AWS Batch will ignore
    this config.
    """

    numNodes: int = 1
    mainNode: int = 0
    nodeRangeProperties: List[
        AWSBatchJobDefinitionNodePropertiesNodeRangeProperty
    ] = []
88100
89101class AWSBatchJobDefinitionRetryStrategy (BaseModel ):
102+ """An AWS Batch job subconfiguration model for retry specifications."""
90103 attempts : int = 2
91104 evaluateOnExit : List [Dict [str ,str ]] = [
92105 {
@@ -102,15 +115,15 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel):
102115 "action" : "EXIT"
103116 }
104117 ]
105- # Example:
106- # {
107- # 'onStatusReason': 'string',
108- # 'onReason': 'string',
109- # 'onExitCode': 'string',
110- # 'action': 'RETRY'|'EXIT'
111- # },
112118
113119class AWSBatchJobDefinition (BaseModel ):
120+ """A utility to validate AWS Batch job descriptions.
121+
122+ Defaults fall into two categories:
123+ - reasonable default values
124+ - aligning the job description to be a valid 'container' type configuration,
125+ as multinode jobs are not supported yet."""
126+
114127 jobDefinitionName : str
115128 type : Literal ['container' ,'multinode' ] = 'container' # we dont support multinode type in this version
116129 parameters : Dict [str ,str ] = {}
@@ -126,15 +139,15 @@ class AWSBatchJobDefinition(BaseModel):
126139
127140
128141class AWSBatchStepOperator (BaseStepOperator ):
129- """Step operator to run a step on Sagemaker .
142+ """Step operator to run a step on AWS Batch .
130143
131144 This class defines code that builds an image with the ZenML entrypoint
132- to run using Sagemaker's Estimator .
145+ to run using AWS Batch .
133146 """
134147
135148 @property
136149 def config (self ) -> AWSBatchStepOperatorConfig :
137- """Returns the `SagemakerStepOperatorConfig ` config.
150+ """Returns the `AWSBatchStepOperatorConfig ` config.
138151
139152 Returns:
140153 The configuration.
@@ -143,7 +156,7 @@ def config(self) -> AWSBatchStepOperatorConfig:
143156
144157 @property
145158 def settings_class (self ) -> Optional [Type ["BaseSettings" ]]:
146- """Settings class for the SageMaker step operator.
159+ """Settings class for the AWS Batch step operator.
147160
148161 Returns:
149162 The settings class.
@@ -204,11 +217,117 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
204217 custom_validation_function = _validate_remote_components ,
205218 )
206219
220+ @staticmethod
221+ def map_environment (environment : Dict [str ,str ]) -> List [Dict [str ,str ]]:
222+ """Utility to map the {name:value} environment to the
223+ [{"name":name,"value":value},] convention used in the AWS Batch job
224+ definition spec.
225+
226+ Args:
227+ environment (Dict[str,str]): The step's environment variable
228+ specification
229+
230+ Returns:
231+ List[Dict[str,str]]: The mapped environment variable specification
232+ """
233+
234+ return [
235+ {"name" :k ,"value" :v } for k ,v in environment
236+ ]
237+
238+ @staticmethod
239+ def map_resource_settings (resource_settings : ResourceSettings ) -> List [Dict [str ,str ]]:
240+ """Utility to map the resource_settings to the resource convention used
241+ in the AWS Batch Job definition spec.
242+
243+ Args:
244+ resource_settings (ResourceSettings): The step's resource settings.
245+
246+ Returns:
247+ List[Dict[str,str]]: The mapped resource settings.
248+ """
249+ mapped_resource_settings = []
250+
251+ if resource_settings .empty :
252+ return mapped_resource_settings
253+ else :
254+
255+ if resource_settings .cpu_count is not None :
256+ mapped_resource_settings .append (
257+ {
258+ "value" : resource_settings .cpu_count ,
259+ "type" : 'VCPU'
260+ }
261+ )
262+
263+ if resource_settings .gpu_count is not None :
264+ mapped_resource_settings .append (
265+ {
266+ "value" : resource_settings .gpu_count ,
267+ "type" : 'GPU'
268+ }
269+ )
270+
271+ if resource_settings .get_memory () is not None :
272+ mapped_resource_settings .append (
273+ {
274+ "value" : resource_settings .get_memory (),
275+ "type" : 'MEMORY'
276+ }
277+ )
278+
279+ return mapped_resource_settings
280+
281+ @staticmethod
282+ def generate_unique_batch_job_name (info : "StepRunInfo" ) -> str :
283+ """Utility to generate a unique AWS Batch job name.
284+
285+ Args:
286+ info (StepRunInfo): The step run information.
287+
288+ Returns:
289+ str: A unique name for the step's AWS Batch job definition
290+ """
291+
292+ # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
293+ step_name = Client ().get_run_step (info .step_run_id ).name
294+ job_name = f"{ info .pipeline .name } -{ step_name } " [:55 ]
295+ suffix = random_str (4 )
296+ return f"{ job_name } -{ suffix } "
207297
208298 def generate_job_definition (self , info : "StepRunInfo" , entrypoint_command : List [str ], environment : Dict [str ,str ]) -> AWSBatchJobDefinition :
209299 """Utility to map zenml internal configurations to a valid AWS Batch
210300 job definition."""
211- pass
301+
302+ image_name = info .get_image (key = BATCH_DOCKER_IMAGE_KEY )
303+
304+ resource_settings = info .config .resource_settings
305+ step_settings = cast (AWSBatchStepOperatorSettings , self .get_settings (info ))
306+
307+ job_name = self .generate_unique_batch_job_name (info )
308+
309+ return AWSBatchJobDefinition (
310+ jobDefinitionName = job_name ,
311+ containerProperties = AWSBatchJobDefinitionContainerProperties (
312+ executionRoleArn = self .config .execution_role ,
313+ jobRoleArn = self .config .job_role ,
314+ image = image_name ,
315+ command = entrypoint_command ,
316+ environment = self .map_environment (environment ),
317+ instanceType = step_settings .instance_type ,
318+ resourceRequirements = self .map_resource_settings (resource_settings ),
319+ ),
320+ timeout = {'attemptDurationSeconds' :step_settings .timeout_seconds },
321+ # type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version
322+ # parameters: Dict[str,str] = {}
323+ # schedulingPriority: int = 0 # ignored in FIFO queues
324+ # nodeProperties: AWSBatchJobDefinitionNodeProperties = AWSBatchJobDefinitionNodeProperties(
325+ # numNodes=1,mainNode=0,nodeRangeProperties=[]) # we'll focus on container mode for now - let's add multinode support later, as that will most likely require network configuration support as well
326+ # retryStrategy: AWSBatchJobDefinitionRetryStrategy = AWSBatchJobDefinitionRetryStrategy()
327+ # propagateTags: bool = False
328+ # tags: Dict[str,str] = {}
329+ # platformCapabilities: Literal['EC2','FARGATE'] = "EC2"
330+ )
212331
213332
214333 def get_docker_builds (
@@ -264,33 +383,20 @@ def launch(
264383 self .name ,
265384 )
266385
267- image_name = info .get_image (key = BATCH_DOCKER_IMAGE_KEY )
268-
269- settings = cast (AWSBatchStepOperatorSettings , self .get_settings (info ))
386+ job_definition = self .generate_job_definition (info , entrypoint_command , environment )
270387
271388 batch = boto3 .client ('batch' )
272-
273- # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
274- step_name = Client ().get_run_step (info .step_run_id ).name
275- training_job_name = f"{ info .pipeline .name } -{ step_name } " [:55 ]
276- suffix = random_str (4 )
277- unique_training_job_name = f"{ training_job_name } -{ suffix } "
278-
389+
279390 response = batch .register_job_definition (
280- jobDefinitionName = unique_training_job_name ,
281- type = 'container' ,
282- containerProperties = {
283- 'image' : image_name ,
284- 'command' : entrypoint_command ,
285- }
391+ ** job_definition .model_dump ()
286392 )
287393
288- job_definition = response ['jobDefinitionName' ]
394+ job_definition_name = response ['jobDefinitionName' ]
289395
290396 response = batch .submit_job (
291- jobName = unique_training_job_name ,
397+ jobName = job_definition . jobDefinitionName ,
292398 jobQueue = self .config .job_queue_name ,
293- jobDefinition = job_definition ,
399+ jobDefinition = job_definition_name ,
294400 )
295401
296402 job_id = response ['jobId' ]
0 commit comments