Skip to content

Commit c6f2a87

Browse files
renaming modules and adding unit tests
1 parent 6b076d1 commit c6f2a87

File tree

9 files changed

+130
-41
lines changed

9 files changed

+130
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ dev = [
250250
"darglint",
251251
"pytest-randomly",
252252
"pytest-mock",
253-
"pytest-clarity",
253+
#"pytest-clarity",
254254
"pytest-instafail",
255255
"pytest-rerunfailures",
256256
"pytest-split",

src/zenml/integrations/aws/flavors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
SagemakerStepOperatorConfig,
3030
SagemakerStepOperatorFlavor,
3131
)
32-
from zenml.integrations.aws.flavors.batch_step_operator_flavor import (
32+
from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import (
3333
AWSBatchStepOperatorConfig,
3434
AWSBatchStepOperatorFlavor
3535
)

src/zenml/integrations/aws/flavors/batch_step_operator_flavor.py renamed to src/zenml/integrations/aws/flavors/aws_batch_step_operator_flavor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Amazon SageMaker step operator flavor."""
1515

16-
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
16+
from typing import TYPE_CHECKING, Dict, Optional, Type
1717

1818
from pydantic import Field
1919

@@ -27,7 +27,6 @@
2727
BaseStepOperatorConfig,
2828
BaseStepOperatorFlavor,
2929
)
30-
from zenml.utils import deprecation_utils
3130

3231
if TYPE_CHECKING:
3332
from zenml.integrations.aws.step_operators import AWSBatchStepOperator
@@ -36,13 +35,13 @@
3635
class AWSBatchStepOperatorSettings(BaseSettings):
3736
"""Settings for the Sagemaker step operator."""
3837

39-
instance_type: Optional[str] = Field(
40-
'optimal',
38+
instance_type: str = Field(
39+
default='optimal',
4140
description="The instance type for AWS Batch to use for the step" \
4241
" execution. Example: 'm5.xlarge'",
4342
)
4443
environment: Dict[str, str] = Field(
45-
default_factory=dict,
44+
default={},
4645
description="Environment variables to pass to the container during " \
4746
"execution. Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
4847
)

src/zenml/integrations/aws/step_operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401
1717
SagemakerStepOperator,
1818
)
19-
from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401
19+
from zenml.integrations.aws.step_operators.aws_batch_step_operator import ( # noqa: F401
2020
AWSBatchStepOperator,
2121
)
2222
__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"]

src/zenml/integrations/aws/step_operators/batch_step_operator.py renamed to src/zenml/integrations/aws/step_operators/aws_batch_step_operator.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,28 @@
1414
"""Implementation of the Sagemaker Step Operator."""
1515

1616
import time
17+
import math
1718
from typing import (
1819
TYPE_CHECKING,
1920
Dict,
2021
List,
2122
Optional,
2223
Tuple,
2324
Type,
24-
Union,
2525
Literal,
2626
cast,
2727
)
2828
from pydantic import BaseModel
29-
3029
import boto3
31-
from sagemaker.estimator import Estimator
32-
from sagemaker.inputs import TrainingInput
33-
from sagemaker.session import Session
3430

3531
from zenml.client import Client
3632
from zenml.config.build_configuration import BuildConfiguration
3733
from zenml.enums import StackComponentType
38-
from zenml.integrations.aws.flavors.batch_step_operator_flavor import (
34+
from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import (
3935
AWSBatchStepOperatorConfig,
4036
AWSBatchStepOperatorSettings,
4137
)
42-
from zenml.integrations.aws.step_operators.batch_step_operator_entrypoint_config import (
43-
BATCH_STEP_ENV_VAR_SIZE_LIMIT,
38+
from zenml.integrations.aws.step_operators.aws_batch_step_operator_entrypoint_config import (
4439
AWSBatchEntrypointConfiguration,
4540
)
4641
from zenml.logger import get_logger
@@ -49,7 +44,6 @@
4944
from zenml.step_operators.step_operator_entrypoint_configuration import (
5045
StepOperatorEntrypointConfiguration,
5146
)
52-
from zenml.utils.env_utils import split_environment_variables
5347
from zenml.utils.string_utils import random_str
5448

5549
if TYPE_CHECKING:
@@ -60,7 +54,7 @@
6054

6155
logger = get_logger(__name__)
6256

63-
BATCH_DOCKER_IMAGE_KEY = "batch_step_operator"
57+
BATCH_DOCKER_IMAGE_KEY = "aws_batch_step_operator"
6458
_ENTRYPOINT_ENV_VARIABLE = "__ZENML_ENTRYPOINT"
6559

6660
class AWSBatchJobDefinitionContainerProperties(BaseModel):
@@ -232,11 +226,11 @@ def map_environment(environment: Dict[str,str]) -> List[Dict[str,str]]:
232226
"""
233227

234228
return [
235-
{"name":k,"value":v} for k,v in environment
229+
{"name":k,"value":v} for k,v in environment.items()
236230
]
237231

238232
@staticmethod
239-
def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,str]]:
233+
def map_resource_settings(resource_settings: "ResourceSettings") -> List[Dict[str,str]]:
240234
"""Utility to map the resource_settings to the resource convention used
241235
in the AWS Batch Job definition spec.
242236
@@ -253,25 +247,31 @@ def map_resource_settings(resource_settings: ResourceSettings) -> List[Dict[str,
253247
else:
254248

255249
if resource_settings.cpu_count is not None:
250+
251+
cpu_count_int = math.ceil(resource_settings.cpu_count)
252+
253+
if cpu_count_int != resource_settings.cpu_count:
254+
logger.info(f"AWS Batch only accepts int type cpu resource requirements. Converted {resource_settings.cpu_count} to {cpu_count_int}")
255+
256256
mapped_resource_settings.append(
257257
{
258-
"value": resource_settings.cpu_count,
258+
"value": str(cpu_count_int),
259259
"type": 'VCPU'
260260
}
261261
)
262262

263263
if resource_settings.gpu_count is not None:
264264
mapped_resource_settings.append(
265265
{
266-
"value": resource_settings.gpu_count,
266+
"value": str(resource_settings.gpu_count),
267267
"type": 'GPU'
268268
}
269269
)
270270

271271
if resource_settings.get_memory() is not None:
272272
mapped_resource_settings.append(
273273
{
274-
"value": resource_settings.get_memory(),
274+
"value": str(int(resource_settings.get_memory(unit="MiB"))),
275275
"type": 'MEMORY'
276276
}
277277
)

src/zenml/integrations/aws/step_operators/batch_step_operator_entrypoint_config.py renamed to src/zenml/integrations/aws/step_operators/aws_batch_step_operator_entrypoint_config.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,7 @@
1616
from zenml.step_operators.step_operator_entrypoint_configuration import (
1717
StepOperatorEntrypointConfiguration,
1818
)
19-
from zenml.utils.env_utils import reconstruct_environment_variables
20-
21-
BATCH_STEP_ENV_VAR_SIZE_LIMIT = 512
2219

2320

2421
class AWSBatchEntrypointConfiguration(StepOperatorEntrypointConfiguration):
25-
"""Entrypoint configuration for ZenML Batch step operator.
26-
27-
The only purpose of this entrypoint configuration is to reconstruct the
28-
environment variables that exceed the maximum length of 512 characters
29-
allowed for Batch steps from their individual components.
30-
"""
31-
32-
def run(self) -> None:
33-
"""Runs the step."""
34-
# Reconstruct the environment variables that exceed the maximum length
35-
# of 512 characters from their individual chunks
36-
reconstruct_environment_variables()
37-
38-
# Run the step
39-
super().run()
22+
"""Entrypoint configuration for ZenML Batch step operator."""
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
16+
import pytest
17+
18+
from zenml.config.resource_settings import ResourceSettings
19+
from zenml.integrations.aws.step_operators.aws_batch_step_operator import AWSBatchStepOperator
20+
21+
def test_aws_batch_step_operator_map_environment():
22+
test_environment = {'key_1':'value_1','key_2':'value_2'}
23+
expected = [
24+
{
25+
"name": "key_1",
26+
"value": "value_1"
27+
},
28+
{
29+
"name": "key_2",
30+
"value": "value_2"
31+
}
32+
]
33+
34+
assert AWSBatchStepOperator.map_environment(test_environment) == expected
35+
36+
@pytest.mark.parametrize(
37+
"test_resource_settings,expected",
38+
[
39+
(
40+
ResourceSettings(),
41+
[]
42+
),
43+
(
44+
ResourceSettings(cpu_count=0.4,gpu_count=1,memory="10MiB"),
45+
[
46+
{
47+
"value":"1",
48+
"type":"VCPU"
49+
},
50+
{
51+
"value":"1",
52+
"type":"GPU"
53+
},
54+
{
55+
"value": "10",
56+
"type": "MEMORY"
57+
}
58+
]
59+
),
60+
(
61+
ResourceSettings(cpu_count=1,gpu_count=1),
62+
[
63+
{
64+
"value":"1",
65+
"type":"VCPU"
66+
},
67+
{
68+
"value":"1",
69+
"type":"GPU"
70+
},
71+
]
72+
),
73+
(
74+
ResourceSettings(memory="1GiB"),
75+
[
76+
{
77+
"value": "1024",
78+
"type": "MEMORY"
79+
}
80+
]
81+
),
82+
83+
]
84+
)
85+
def test_aws_batch_step_operator_map_resource_settings(test_resource_settings,expected):
86+
assert AWSBatchStepOperator.map_resource_settings(test_resource_settings) == expected
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from zenml.integrations.aws.flavors.aws_batch_step_operator_flavor import AWSBatchStepOperatorSettings
2+
3+
def test_aws_batch_step_operator_settings():
4+
AWSBatchStepOperatorSettings(
5+
instance_type="g4dn.xlarge",
6+
environment={"key_1":"value_1","key_2":"value_2"},
7+
timeout_seconds=60
8+
)

0 commit comments

Comments
 (0)