Skip to content

Commit 0906233

Browse files
committed
Optional entrypoint args
1 parent 6b51cc5 commit 0906233

File tree

8 files changed

+52
-45
lines changed

8 files changed

+52
-45
lines changed

src/zenml/deployers/server/entrypoint_configuration.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""ZenML Pipeline Deployment Entrypoint Configuration."""
1515

1616
import os
17-
from typing import Any, List, Set
17+
from typing import Any, Dict, List
1818
from uuid import UUID
1919

2020
from zenml.client import Client
@@ -46,20 +46,20 @@ class DeploymentEntrypointConfiguration(BaseEntrypointConfiguration):
4646
"""
4747

4848
@classmethod
49-
def get_entrypoint_options(cls) -> Set[str]:
49+
def get_entrypoint_options(cls) -> Dict[str, bool]:
5050
"""Gets all options required for the deployment entrypoint.
5151
5252
Returns:
5353
Set of required option names
5454
"""
5555
return {
56-
DEPLOYMENT_ID_OPTION,
57-
HOST_OPTION,
58-
PORT_OPTION,
59-
WORKERS_OPTION,
60-
LOG_LEVEL_OPTION,
61-
CREATE_RUNS_OPTION,
62-
AUTH_KEY_OPTION,
56+
DEPLOYMENT_ID_OPTION: True,
57+
HOST_OPTION: True,
58+
PORT_OPTION: True,
59+
WORKERS_OPTION: True,
60+
LOG_LEVEL_OPTION: True,
61+
CREATE_RUNS_OPTION: True,
62+
AUTH_KEY_OPTION: True,
6363
}
6464

6565
@classmethod

src/zenml/entrypoints/base_entrypoint_configuration.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import sys
1919
from abc import ABC, abstractmethod
20-
from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Set
20+
from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional
2121
from uuid import UUID
2222

2323
from zenml.client import Client
@@ -83,18 +83,18 @@ def get_entrypoint_command(cls) -> List[str]:
8383
return DEFAULT_ENTRYPOINT_COMMAND
8484

8585
@classmethod
86-
def get_entrypoint_options(cls) -> Set[str]:
86+
def get_entrypoint_options(cls) -> Dict[str, bool]:
8787
"""Gets all options required for running with this configuration.
8888
8989
Returns:
90-
A set of strings with all required options.
90+
A dictionary of options and whether they are required.
9191
"""
9292
return {
9393
# Importable source pointing to the entrypoint configuration class
9494
# that should be used inside the entrypoint.
95-
ENTRYPOINT_CONFIG_SOURCE_OPTION,
95+
ENTRYPOINT_CONFIG_SOURCE_OPTION: True,
9696
# ID of the pipeline snapshot to use in this entrypoint
97-
SNAPSHOT_ID_OPTION,
97+
SNAPSHOT_ID_OPTION: True,
9898
}
9999

100100
@classmethod
@@ -178,13 +178,13 @@ def error(self, message: str) -> NoReturn:
178178

179179
parser = _CustomParser()
180180

181-
for option_name in cls.get_entrypoint_options():
181+
for option_name, required in cls.get_entrypoint_options().items():
182182
if option_name == ENTRYPOINT_CONFIG_SOURCE_OPTION:
183183
# This option is already used by
184184
# `zenml.entrypoints.entrypoint` to read which config
185185
# class to use
186186
continue
187-
parser.add_argument(f"--{option_name}", required=True)
187+
parser.add_argument(f"--{option_name}", required=required)
188188

189189
result, _ = parser.parse_known_args(arguments)
190190
return vars(result)

src/zenml/entrypoints/step_entrypoint_configuration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import os
1717
import sys
18-
from typing import TYPE_CHECKING, Any, List, Set
18+
from typing import TYPE_CHECKING, Any, Dict, List
1919
from uuid import UUID
2020

2121
from zenml.client import Client
@@ -115,14 +115,14 @@ def post_run(
115115
"""
116116

117117
@classmethod
118-
def get_entrypoint_options(cls) -> Set[str]:
118+
def get_entrypoint_options(cls) -> Dict[str, bool]:
119119
"""Gets all options required for running with this configuration.
120120
121121
Returns:
122122
The superclass options as well as an option for the name of the
123123
step to run.
124124
"""
125-
return super().get_entrypoint_options() | {STEP_NAME_OPTION}
125+
return super().get_entrypoint_options() | {STEP_NAME_OPTION: True}
126126

127127
@classmethod
128128
def get_entrypoint_arguments(

src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import json
1717
import os
18-
from typing import Any, List, Set
18+
from typing import Any, Dict, List
1919

2020
from zenml.entrypoints.step_entrypoint_configuration import (
2121
StepEntrypointConfiguration,
@@ -30,14 +30,14 @@ class AzureMLEntrypointConfiguration(StepEntrypointConfiguration):
3030
"""Entrypoint configuration for ZenML AzureML pipeline steps."""
3131

3232
@classmethod
33-
def get_entrypoint_options(cls) -> Set[str]:
33+
def get_entrypoint_options(cls) -> Dict[str, bool]:
3434
"""Gets all options required for running with this configuration.
3535
3636
Returns:
3737
The superclass options as well as an option for the
3838
environmental variables.
3939
"""
40-
return super().get_entrypoint_options() | {ZENML_ENV_VARIABLES}
40+
return super().get_entrypoint_options() | {ZENML_ENV_VARIABLES: True}
4141

4242
@classmethod
4343
def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]:

src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import sys
1818
from importlib.metadata import distribution
19-
from typing import Any, List, Set
19+
from typing import Any, Dict, List
2020

2121
from zenml.entrypoints.step_entrypoint_configuration import (
2222
StepEntrypointConfiguration,
@@ -38,17 +38,16 @@ class DatabricksEntrypointConfiguration(StepEntrypointConfiguration):
3838
"""
3939

4040
@classmethod
41-
def get_entrypoint_options(cls) -> Set[str]:
41+
def get_entrypoint_options(cls) -> Dict[str, bool]:
4242
"""Gets all options required for running with this configuration.
4343
4444
Returns:
4545
The superclass options as well as an option for the wheel package.
4646
"""
47-
return (
48-
super().get_entrypoint_options()
49-
| {WHEEL_PACKAGE_OPTION}
50-
| {DATABRICKS_JOB_ID_OPTION}
51-
)
47+
return super().get_entrypoint_options() | {
48+
WHEEL_PACKAGE_OPTION: True,
49+
DATABRICKS_JOB_ID_OPTION: True,
50+
}
5251

5352
@classmethod
5453
def get_entrypoint_arguments(

src/zenml/pipelines/dynamic/entrypoint_configuration.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# permissions and limitations under the License.
2828
"""Abstract base class for entrypoint configurations that run a pipeline."""
2929

30-
from typing import Any, List, Set
30+
from typing import Any, Dict, List
3131
from uuid import UUID
3232

3333
from zenml.client import Client
@@ -44,14 +44,14 @@ class DynamicPipelineEntrypointConfiguration(BaseEntrypointConfiguration):
4444
"""Base class for entrypoint configurations that run an entire pipeline."""
4545

4646
@classmethod
47-
def get_entrypoint_options(cls) -> Set[str]:
47+
def get_entrypoint_options(cls) -> Dict[str, bool]:
4848
"""Gets all options required for running with this configuration.
4949
5050
Returns:
5151
The superclass options as well as an option for the name of the
5252
step to run.
5353
"""
54-
return super().get_entrypoint_options() | {RUN_ID_OPTION}
54+
return super().get_entrypoint_options() | {RUN_ID_OPTION: False}
5555

5656
@classmethod
5757
def get_entrypoint_arguments(
@@ -73,10 +73,15 @@ def get_entrypoint_arguments(
7373
The superclass arguments as well as arguments for the name of the
7474
step to run.
7575
"""
76-
return super().get_entrypoint_arguments(**kwargs) + [
77-
f"--{RUN_ID_OPTION}",
78-
str(kwargs[RUN_ID_OPTION]),
79-
]
76+
args = super().get_entrypoint_arguments(**kwargs)
77+
if RUN_ID_OPTION in kwargs:
78+
args.extend(
79+
[
80+
f"--{RUN_ID_OPTION}",
81+
str(kwargs[RUN_ID_OPTION]),
82+
]
83+
)
84+
return args
8085

8186
def run(self) -> None:
8287
"""Prepares the environment and runs the configured pipeline."""
@@ -88,9 +93,9 @@ def run(self) -> None:
8893

8994
self.download_code_if_necessary(snapshot=snapshot)
9095

91-
# TODO: make this optional
92-
run_id = UUID(self.entrypoint_args[RUN_ID_OPTION])
93-
run = Client().get_pipeline_run(run_id)
96+
run = None
97+
if run_id := self.entrypoint_args.get(RUN_ID_OPTION, None):
98+
run = Client().get_pipeline_run(UUID(run_id))
9499

95100
runner = DynamicPipelineRunner(snapshot=snapshot, run=run)
96101
runner.run_pipeline()

src/zenml/step_operators/step_operator_entrypoint_configuration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# permissions and limitations under the License.
1414
"""Abstract base class for entrypoint configurations that run a single step."""
1515

16-
from typing import TYPE_CHECKING, Any, List, Set
16+
from ast import Dict
17+
from typing import TYPE_CHECKING, Any, Dict, List
1718
from uuid import UUID
1819

1920
from zenml.client import Client
@@ -36,14 +37,14 @@ class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration):
3637
"""Base class for step operator entrypoint configurations."""
3738

3839
@classmethod
39-
def get_entrypoint_options(cls) -> Set[str]:
40+
def get_entrypoint_options(cls) -> Dict[str, bool]:
4041
"""Gets all options required for running with this configuration.
4142
4243
Returns:
4344
The superclass options as well as an option for the step run id.
4445
"""
4546
return super().get_entrypoint_options() | {
46-
STEP_RUN_ID_OPTION,
47+
STEP_RUN_ID_OPTION: True,
4748
}
4849

4950
@classmethod

src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Runner entrypoint configuration."""
1515

16-
from typing import Any, List, Set
16+
from typing import Any, Dict, List
1717
from uuid import UUID
1818

1919
from zenml.client import Client
@@ -29,14 +29,16 @@ class RunnerEntrypointConfiguration(BaseEntrypointConfiguration):
2929
"""Runner entrypoint configuration."""
3030

3131
@classmethod
32-
def get_entrypoint_options(cls) -> Set[str]:
32+
def get_entrypoint_options(cls) -> Dict[str, bool]:
3333
"""Gets all options required for running with this configuration.
3434
3535
Returns:
3636
The superclass options as well as an option for the name of the
3737
step to run.
3838
"""
39-
return super().get_entrypoint_options() | {PLACEHOLDER_RUN_ID_OPTION}
39+
return super().get_entrypoint_options() | {
40+
PLACEHOLDER_RUN_ID_OPTION: True
41+
}
4042

4143
@classmethod
4244
def get_entrypoint_arguments(

0 commit comments

Comments
 (0)