
Commit 3d8c39b

finished first draft of aws batch step operator
1 parent e372f85 commit 3d8c39b

File tree

2 files changed: +164 -50 lines changed

src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py

Lines changed: 21 additions & 13 deletions
@@ -37,35 +37,43 @@ class AWSBatchStepOperatorSettings(BaseSettings):
     """Settings for the Sagemaker step operator."""

     instance_type: Optional[str] = Field(
-        None,
-        description="DEPRECATED: The instance type to use for the step execution. "
-        "Use estimator_args instead. Example: 'ml.m5.xlarge'",
+        'optimal',
+        description="The instance type for AWS Batch to use for the step"
+        " execution. Example: 'm5.xlarge'",
     )
     environment: Dict[str, str] = Field(
         default_factory=dict,
-        description="Environment variables to pass to the container during execution. "
-        "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
+        description="Environment variables to pass to the container during "
+        "execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
     )
-
-    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
-        "instance_type"
+    timeout_seconds: int = Field(
+        default=120,
+        description="The number of seconds before AWS Batch times out the job."
     )

-
 class AWSBatchStepOperatorConfig(
     BaseStepOperatorConfig, AWSBatchStepOperatorSettings
 ):
-    """Config for the AWS Batch step operator."""
+    """Config for the AWS Batch step operator.
+
+    Note: We use ECS as a backend (not EKS), and EC2 as a compute engine (not
+    Fargate). This is because
+    - users can avoid the complexity of setting up an EKS cluster, and
+    - we can add AWS Batch multinode job support later, which requires EC2.
+    """

     execution_role: str = Field(
         "",
         description="The ECS execution role required to execute the AWS Batch" \
         " jobs as ECS tasks."
     )
     job_role: str = Field(
         "",
-        description="The ECS job role required by the container runtime insdide" \
-        "the ECS task implementing the zenml step."
+        description="The ECS job role required by the container runtime inside"
+        " the ECS task."
+    )
+    job_queue_name: str = Field(
+        description="The AWS Batch job queue to submit AWS Batch jobs to."
     )

     @property
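
For context, a minimal sketch of how these settings might be attached to a step once a component with this flavor exists in the stack. The "aws_batch" component name, the "step_operator" settings key and the concrete values are illustrative assumptions, not part of this commit:

from zenml import step
from zenml.integrations.aws.flavors.batch_step_operator_flavor import (
    AWSBatchStepOperatorSettings,
)

@step(
    step_operator="aws_batch",  # hypothetical component name in the active stack
    settings={
        # generic settings key; a flavor-specific key may apply instead
        "step_operator": AWSBatchStepOperatorSettings(
            instance_type="m5.xlarge",
            environment={"LOG_LEVEL": "INFO"},
            timeout_seconds=600,
        )
    },
)
def train() -> None:
    ...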

src/zenml/integrations/aws/step_operators/batch_step_operator.py

Lines changed: 143 additions & 37 deletions
@@ -54,6 +54,7 @@

 if TYPE_CHECKING:
     from zenml.config.base_settings import BaseSettings
+    from zenml.config import ResourceSettings
     from zenml.config.step_run_info import StepRunInfo
     from zenml.models import PipelineDeploymentBase

@@ -63,6 +64,7 @@
 _ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT"

 class AWSBatchJobDefinitionContainerProperties(BaseModel):
+    """An AWS Batch job subconfiguration model for a container specification."""
     image: str
     command: List[str]
     jobRoleArn: str
@@ -73,20 +75,31 @@ class AWSBatchJobDefinitionContainerProperties(BaseModel):
     secrets: List[Dict[str,str]] = [] # keys: 'name','value'

 class AWSBatchJobDefinitionNodePropertiesNodeRangeProperty(BaseModel):
+    """An AWS Batch job subconfiguration model for a node in a multinode job
+    specification.
+
+    Note: We include this class for completeness' sake, to make it easier to
+    add multinode support later.
+    """
     targetNodes: str
     container: AWSBatchJobDefinitionContainerProperties

 class AWSBatchJobDefinitionNodeProperties(BaseModel):
-    # we include this class for completeness sake to make it easier
-    # to add multinode support later
-    # for now, we'll set defaults to intuitively represent the only supported
-    # exeuction type ('container'); in reality AWS Batch will ignore this
-    # config
+    """An AWS Batch job subconfiguration model for multinode job specifications.
+
+    Note: We include this class for completeness' sake, to make it easier to
+    add multinode support later. For now, we set defaults that intuitively
+    represent the only supported execution type ('container'); in reality AWS
+    Batch will ignore this config.
+    """
     numNodes: int = 1
     mainNode: int = 0
-    nodeRangeProperties: List[AWSBatchJobDefinitionNodePropertiesNodeRangeProperty] = []
+    nodeRangeProperties: List[
+        AWSBatchJobDefinitionNodePropertiesNodeRangeProperty
+    ] = []

 class AWSBatchJobDefinitionRetryStrategy(BaseModel):
+    """An AWS Batch job subconfiguration model for retry specifications."""
     attempts: int = 2
     evaluateOnExit: List[Dict[str,str]] = [
         {
@@ -102,15 +115,15 @@ class AWSBatchJobDefinitionRetryStrategy(BaseModel):
             "action": "EXIT"
         }
     ]
-    # Example:
-    # {
-    #     'onStatusReason': 'string',
-    #     'onReason': 'string',
-    #     'onExitCode': 'string',
-    #     'action': 'RETRY'|'EXIT'
-    # },

 class AWSBatchJobDefinition(BaseModel):
+    """A utility to validate AWS Batch job definitions.
+
+    Defaults fall into two categories:
+    - reasonable default values
+    - aligning the job definition to be a valid 'container' type
+      configuration, as multinode jobs are not supported yet.
+    """
     jobDefinitionName: str
     type: Literal['container','multinode'] = 'container' # we dont support multinode type in this version
     parameters: Dict[str,str] = {}
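
Because the job definition is a plain pydantic model, the exact payload later handed to boto3 can be inspected offline. A sketch, assuming the constructor arguments used in generate_job_definition further down cover all required fields (the ARNs, image and resource values are placeholders):

job_definition = AWSBatchJobDefinition(
    jobDefinitionName="my-pipeline-trainer-a1b2",
    containerProperties=AWSBatchJobDefinitionContainerProperties(
        executionRoleArn="arn:aws:iam::123456789012:role/ecs-execution-role",
        jobRoleArn="arn:aws:iam::123456789012:role/ecs-job-role",
        image="123456789012.dkr.ecr.eu-west-1.amazonaws.com/zenml:step",
        command=["python", "-m", "zenml.entrypoints.entrypoint"],
        environment=[{"name": "LOG_LEVEL", "value": "INFO"}],
        instanceType="m5.xlarge",
        resourceRequirements=[{"value": "4", "type": "VCPU"}],
    ),
)

# Everything else falls back to the model defaults ('container' type, empty
# parameters, EC2 platform capability, default retry strategy).
print(job_definition.model_dump())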
@@ -126,15 +139,15 @@ class AWSBatchJobDefinition(BaseModel):


 class AWSBatchStepOperator(BaseStepOperator):
-    """Step operator to run a step on Sagemaker.
+    """Step operator to run a step on AWS Batch.

     This class defines code that builds an image with the ZenML entrypoint
-    to run using Sagemaker's Estimator.
+    to run using AWS Batch.
     """

     @property
     def config(self) -> AWSBatchStepOperatorConfig:
-        """Returns the `SagemakerStepOperatorConfig` config.
+        """Returns the `AWSBatchStepOperatorConfig` config.

         Returns:
             The configuration.
@@ -143,7 +156,7 @@ def config(self) -> AWSBatchStepOperatorConfig:

     @property
     def settings_class(self) -> Optional[Type["BaseSettings"]]:
-        """Settings class for the SageMaker step operator.
+        """Settings class for the AWS Batch step operator.

         Returns:
             The settings class.
@@ -204,11 +217,117 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
             custom_validation_function=_validate_remote_components,
         )

+    @staticmethod
+    def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]:
+        """Utility to map the {name: value} environment to the
+        [{"name": name, "value": value}, ...] convention used in the AWS
+        Batch job definition spec.
+
+        Args:
+            environment (Dict[str,str]): The step's environment variable
+                specification.
+
+        Returns:
+            List[Dict[str,str]]: The mapped environment variable specification.
+        """
+        return [
+            {"name": k, "value": v} for k, v in environment.items()
+        ]
+
+    @staticmethod
+    def map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[str,str]]:
+        """Utility to map the resource_settings to the resource convention
+        used in the AWS Batch job definition spec.
+
+        Args:
+            resource_settings (ResourceSettings): The step's resource settings.
+
+        Returns:
+            List[Dict[str,str]]: The mapped resource settings.
+        """
+        mapped_resource_settings = []
+
+        if resource_settings.empty:
+            return mapped_resource_settings
+
+        # AWS Batch expects resource requirement values as strings.
+        if resource_settings.cpu_count is not None:
+            mapped_resource_settings.append(
+                {
+                    "value": str(resource_settings.cpu_count),
+                    "type": 'VCPU'
+                }
+            )
+
+        if resource_settings.gpu_count is not None:
+            mapped_resource_settings.append(
+                {
+                    "value": str(resource_settings.gpu_count),
+                    "type": 'GPU'
+                }
+            )
+
+        if resource_settings.get_memory() is not None:
+            mapped_resource_settings.append(
+                {
+                    "value": str(resource_settings.get_memory()),
+                    "type": 'MEMORY'
+                }
+            )
+
+        return mapped_resource_settings
+
+    @staticmethod
+    def generate_unique_batch_job_name(info: "StepRunInfo") -> str:
+        """Utility to generate a unique AWS Batch job name.
+
+        Args:
+            info (StepRunInfo): The step run information.
+
+        Returns:
+            str: A unique name for the step's AWS Batch job definition.
+        """
+        # Batch allows 63 characters at maximum for the job name - ZenML
+        # uses 60 for a safety margin.
+        step_name = Client().get_run_step(info.step_run_id).name
+        job_name = f"{info.pipeline.name}-{step_name}"[:55]
+        suffix = random_str(4)
+        return f"{job_name}-{suffix}"

     def generate_job_definition(self, info: "StepRunInfo", entrypoint_command: List[str], environment: Dict[str,str]) -> AWSBatchJobDefinition:
         """Utility to map zenml internal configurations to a valid AWS Batch
         job definition."""
-        pass
+        image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY)
+
+        resource_settings = info.config.resource_settings
+        step_settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info))
+
+        job_name = self.generate_unique_batch_job_name(info)
+
+        return AWSBatchJobDefinition(
+            jobDefinitionName=job_name,
+            containerProperties=AWSBatchJobDefinitionContainerProperties(
+                executionRoleArn=self.config.execution_role,
+                jobRoleArn=self.config.job_role,
+                image=image_name,
+                command=entrypoint_command,
+                environment=self.map_environment(environment),
+                instanceType=step_settings.instance_type,
+                resourceRequirements=self.map_resource_settings(resource_settings),
+            ),
+            timeout={'attemptDurationSeconds': step_settings.timeout_seconds},
+            # All remaining fields keep their model defaults: type='container',
+            # parameters={}, schedulingPriority=0 (ignored in FIFO queues),
+            # nodeProperties (ignored for 'container' jobs), retryStrategy,
+            # propagateTags=False, tags={} and platformCapabilities='EC2'.
+            # Multinode support can be added later; it will most likely
+            # require network configuration support as well.
+        )


     def get_docker_builds(
@@ -264,33 +383,20 @@ def launch(
             self.name,
         )

-        image_name = info.get_image(key=BATCH_DOCKER_IMAGE_KEY)
-
-        settings = cast(AWSBatchStepOperatorSettings, self.get_settings(info))
+        job_definition = self.generate_job_definition(info, entrypoint_command, environment)

         batch = boto3.client('batch')
-
-        # Batch allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
-        step_name = Client().get_run_step(info.step_run_id).name
-        training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
-        suffix = random_str(4)
-        unique_training_job_name = f"{training_job_name}-{suffix}"

         response = batch.register_job_definition(
-            jobDefinitionName=unique_training_job_name,
-            type='container',
-            containerProperties={
-                'image': image_name,
-                'command': entrypoint_command,
-            }
+            **job_definition.model_dump()
         )

-        job_definition = response['jobDefinitionName']
+        job_definition_name = response['jobDefinitionName']

         response = batch.submit_job(
-            jobName=unique_training_job_name,
+            jobName=job_definition.jobDefinitionName,
             jobQueue=self.config.job_queue_name,
-            jobDefinition=job_definition,
+            jobDefinition=job_definition_name,
         )

         job_id = response['jobId']