Skip to content

Commit e87d336

Browse files
started creating required files and mapping out the zenml config -> aws batch job definition conversion
1 parent ce1de79 commit e87d336

File tree

7 files changed

+518
-5
lines changed

7 files changed

+518
-5
lines changed

src/zenml/integrations/aws/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AWS_CONTAINER_REGISTRY_FLAVOR = "aws"
2929
AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR = "sagemaker"
3030
AWS_SAGEMAKER_ORCHESTRATOR_FLAVOR = "sagemaker"
31+
AWS_BATCH_STEP_OPERATOR_FLAVOR = "aws_batch"
3132

3233
# Service connector constants
3334
AWS_CONNECTOR_TYPE = "aws"
@@ -42,6 +43,7 @@ class AWSIntegration(Integration):
4243
NAME = AWS
4344
REQUIREMENTS = [
4445
"sagemaker>=2.237.3",
46+
"boto3>=1.40.30",
4547
"kubernetes",
4648
"aws-profile-manager",
4749
]
@@ -64,11 +66,13 @@ def flavors(cls) -> List[Type[Flavor]]:
6466
AWSImageBuilderFlavor,
6567
SagemakerOrchestratorFlavor,
6668
SagemakerStepOperatorFlavor,
69+
AWSBatchStepOperatorFlavor
6770
)
6871

6972
return [
7073
AWSContainerRegistryFlavor,
7174
AWSImageBuilderFlavor,
7275
SagemakerStepOperatorFlavor,
7376
SagemakerOrchestratorFlavor,
77+
AWSBatchStepOperatorFlavor
7478
]

src/zenml/integrations/aws/flavors/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
SagemakerStepOperatorConfig,
3030
SagemakerStepOperatorFlavor,
3131
)
32+
from zenml.integrations.aws.flavors.batch_step_operator_flavor import (
33+
AWSBatchStepOperatorConfig,
34+
AWSBatchStepOperatorFlavor
35+
)
3236

3337
__all__ = [
3438
"AWSContainerRegistryFlavor",
@@ -39,4 +43,7 @@
3943
"SagemakerStepOperatorConfig",
4044
"SagemakerOrchestratorFlavor",
4145
"SagemakerOrchestratorConfig",
46+
"AWSBatchStepOperatorFlavor",
47+
"AWSBatchStepOperatorConfig",
48+
4249
]
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Amazon SageMaker step operator flavor."""
15+
16+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
17+
18+
from pydantic import Field
19+
20+
from zenml.config.base_settings import BaseSettings
21+
from zenml.integrations.aws import (
22+
AWS_RESOURCE_TYPE,
23+
AWS_BATCH_STEP_OPERATOR_FLAVOR,
24+
)
25+
from zenml.models import ServiceConnectorRequirements
26+
from zenml.step_operators.base_step_operator import (
27+
BaseStepOperatorConfig,
28+
BaseStepOperatorFlavor,
29+
)
30+
from zenml.utils import deprecation_utils
31+
32+
if TYPE_CHECKING:
33+
from zenml.integrations.aws.step_operators import AWSBatchStepOperator
34+
35+
36+
class AWSBatchStepOperatorSettings(BaseSettings):
37+
"""Settings for the Sagemaker step operator."""
38+
39+
instance_type: Optional[str] = Field(
40+
None,
41+
description="DEPRECATED: The instance type to use for the step execution. "
42+
"Use estimator_args instead. Example: 'ml.m5.xlarge'",
43+
)
44+
environment: Dict[str, str] = Field(
45+
default_factory=dict,
46+
description="Environment variables to pass to the container during execution. "
47+
"Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}",
48+
)
49+
50+
_deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
51+
"instance_type"
52+
)
53+
54+
55+
class AWSBatchStepOperatorConfig(
56+
BaseStepOperatorConfig, AWSBatchStepOperatorSettings
57+
):
58+
"""Config for the AWS Batch step operator."""
59+
60+
execution_role: str = Field(
61+
"",
62+
description="The ECS execution role required to execute the AWS Batch" \
63+
" jobs as ECS tasks."
64+
)
65+
job_role: str = Field(
66+
"",
67+
description="The ECS job role required by the container runtime insdide" \
68+
"the ECS task implementing the zenml step."
69+
)
70+
71+
@property
72+
def is_remote(self) -> bool:
73+
"""Checks if this stack component is running remotely.
74+
75+
This designation is used to determine if the stack component can be
76+
used with a local ZenML database or if it requires a remote ZenML
77+
server.
78+
79+
Returns:
80+
True if this config is for a remote component, False otherwise.
81+
"""
82+
return True
83+
84+
85+
class AWSBatchStepOperatorFlavor(BaseStepOperatorFlavor):
86+
"""Flavor for the AWS Batch step operator."""
87+
88+
@property
89+
def name(self) -> str:
90+
"""Name of the flavor.
91+
92+
Returns:
93+
The name of the flavor.
94+
"""
95+
return AWS_BATCH_STEP_OPERATOR_FLAVOR
96+
97+
@property
98+
def service_connector_requirements(
99+
self,
100+
) -> Optional[ServiceConnectorRequirements]:
101+
"""Service connector resource requirements for service connectors.
102+
103+
Specifies resource requirements that are used to filter the available
104+
service connector types that are compatible with this flavor.
105+
106+
Returns:
107+
Requirements for compatible service connectors, if a service
108+
connector is required for this flavor.
109+
"""
110+
return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE)
111+
112+
@property
113+
def docs_url(self) -> Optional[str]:
114+
"""A url to point at docs explaining this flavor.
115+
116+
Returns:
117+
A flavor docs url.
118+
"""
119+
return self.generate_default_docs_url()
120+
121+
@property
122+
def sdk_docs_url(self) -> Optional[str]:
123+
"""A url to point at SDK docs explaining this flavor.
124+
125+
Returns:
126+
A flavor SDK docs url.
127+
"""
128+
return self.generate_default_sdk_docs_url()
129+
130+
@property
131+
def logo_url(self) -> str:
132+
"""A url to represent the flavor in the dashboard.
133+
134+
Returns:
135+
The flavor logo.
136+
"""
137+
return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/aws_batch.png"
138+
139+
@property
140+
def config_class(self) -> Type[AWSBatchStepOperatorConfig]:
141+
"""Returns BatchStepOperatorConfig config class.
142+
143+
Returns:
144+
The config class.
145+
"""
146+
return AWSBatchStepOperatorConfig
147+
148+
@property
149+
def implementation_class(self) -> Type["AWSBatchStepOperator"]:
150+
"""Implementation class.
151+
152+
Returns:
153+
The implementation class.
154+
"""
155+
from zenml.integrations.aws.step_operators import AWSBatchStepOperator
156+
157+
return AWSBatchStepOperator

src/zenml/integrations/aws/step_operators/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# permissions and limitations under the License.
1414
"""Initialization of the Sagemaker Step Operator."""
1515

16-
from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa
16+
from zenml.integrations.aws.step_operators.sagemaker_step_operator import ( # noqa: F401
1717
SagemakerStepOperator,
1818
)
19-
20-
__all__ = ["SagemakerStepOperator"]
19+
from zenml.integrations.aws.step_operators.batch_step_operator import ( # noqa: F401
20+
AWSBatchStepOperator,
21+
)
22+
__all__ = ["SagemakerStepOperator","AWSBatchStepOperator"]

0 commit comments

Comments
 (0)