Skip to content

Commit c787379

Browse files
added support for multinode aws batch job type
1 parent c9b5829 commit c787379

File tree

4 files changed

+61
-3
lines changed

4 files changed

+61
-3
lines changed

src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class AWSBatchStepOperatorSettings(BaseSettings):
4646
"execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
4747
)
4848
node_count: PositiveInt = Field(
49+
default=1,
4950
description="The number of AWS Batch nodes to run the step on. If > 1," \
5051
"an AWS Batch multinode job will be run, with the network connectivity" \
5152
"between the nodes provided by AWS Batch. See https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html" \

src/zenml/integrations/aws/step_operators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
)
1919
from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401
2020
AWSBatchStepOperator,
21+
get_aws_batch_context
2122
)
22-
__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"]
23+
__all__ = ["SagemakerStepOperator","AWSBatchStepOperator","get_aws_batch_context"]

src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
Literal,
2626
cast,
2727
)
28-
from pydantic import BaseModel, PositiveInt
28+
from pydantic import BaseModel, PositiveInt, Field
29+
from pydantic_settings import BaseSettings as PydanticBaseSettings
2930
import boto3
3031

3132
from zenml.client import Client
@@ -57,6 +58,43 @@
5758
BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator"
5859
_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT"
5960

61+
class AWSBatchContext(PydanticBaseSettings):
62+
"""A utility to access the AWS Batch job context environment variables."""
63+
main_node_index: int = Field(
64+
description="This variable is set to the index number of the job's "
65+
"main node. Your application code can compare the "
66+
"AWS_BATCH_JOB_MAIN_NODE_INDEX to the AWS_BATCH_JOB_NODE_INDEX on "
67+
"an individual node to determine if it's the main node.",
68+
alias="AWS_BATCH_JOB_MAIN_NODE_INDEX"
69+
)
70+
71+
main_node_address: str = Field(
72+
description="This variable is only set in multi-node parallel job "
73+
"child nodes. This variable isn't present on the main node. This "
74+
"variable is set to the private IPv4 address of the job's main "
75+
"node. Your child node's application code can use this address to "
76+
"communicate with the main node.",
77+
alias="AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"
78+
)
79+
80+
node_index: int = Field(
81+
description="This variable is set to the node index number of the "
82+
"node. The node index begins at 0, and each node receives a unique"
83+
" index number. For example, a multi-node parallel job with 10 "
84+
"children has index values of 0-9.",
85+
alias="AWS_BATCH_JOB_NODE_INDEX"
86+
)
87+
88+
num_nodes: int = Field(
89+
description="This variable is set to the number of nodes that you have"
90+
"requested for your multi-node parallel job.",
91+
alias="AWS_BATCH_JOB_NUM_NODES"
92+
)
93+
94+
def get_aws_batch_context() -> AWSBatchContext:
95+
"""Utility to retrieve the AWS Batch runtime context."""
96+
return AWSBatchContext()
97+
6098
class AWSBatchJobDefinitionContainerProperties(BaseModel):
6199
"""An AWS Batch job subconfiguration model for a container specification."""
62100
image: str

tests/integration/integrations/aws/step_operators/test_aws_batch_step_operator.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,25 @@
1616
import pytest
1717

1818
from zenml.config.resource_settings import ResourceSettings
19-
from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator
19+
from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator, get_aws_batch_context
20+
21+
def test_aws_batch_context(monkeypatch):
22+
"""Tests the AWSBatchContext class."""
23+
24+
monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_INDEX',0)
25+
monkeypatch.setenv('AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS','test-address')
26+
monkeypatch.setenv('AWS_BATCH_JOB_NODE_INDEX',1)
27+
monkeypatch.setenv('AWS_BATCH_JOB_NUM_NODES',2)
28+
29+
test_aws_batch_context = get_aws_batch_context()
30+
assert test_aws_batch_context.main_node_index == 0
31+
assert test_aws_batch_context.main_node_address == 'test-address'
32+
assert test_aws_batch_context.node_index == 1
33+
assert test_aws_batch_context.num_nodes == 2
2034

2135
def test_aws_batch_step_operator_map_environment():
36+
"""Tests the AWSBatchStepOperator's map_environment class method."""
37+
2238
test_environment = {'key_1':'value_1','key_2':'value_2'}
2339
expected = [
2440
{
@@ -83,4 +99,6 @@ def test_aws_batch_step_operator_map_environment():
8399
]
84100
)
85101
def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected):
102+
"""Tests the AWSBatchStepOperator's map_resource_settings class method."""
103+
86104
assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected

0 commit comments

Comments
 (0)