Skip to content

Commit 093ee9a

Browse files
committed
Misc fixes
1 parent 60e9358 commit 093ee9a

File tree

9 files changed

+125
-75
lines changed

9 files changed

+125
-75
lines changed

src/zenml/deployers/server/entrypoint_configuration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def run(self) -> None:
113113
raise RuntimeError(f"Deployment {deployment.id} has no snapshot")
114114

115115
# Download code if necessary (for remote execution environments)
116-
self.download_code_if_necessary(snapshot=deployment.snapshot)
116+
self.download_code_if_necessary()
117117

118118
app_runner = BaseDeploymentAppRunner.load_app_runner(deployment)
119119
app_runner.run()

src/zenml/entrypoints/base_entrypoint_configuration.py

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
if TYPE_CHECKING:
3636
from zenml.artifact_stores import BaseArtifactStore
37+
from zenml.config import DockerSettings
3738
from zenml.models import CodeReferenceResponse, PipelineSnapshotResponse
3839

3940
logger = get_logger(__name__)
@@ -64,6 +65,7 @@ def __init__(self, arguments: List[str]):
6465
arguments: Command line arguments to configure this object.
6566
"""
6667
self.entrypoint_args = self._parse_arguments(arguments)
68+
self._snapshot: Optional["PipelineSnapshotResponse"] = None
6769

6870
@classmethod
6971
def get_entrypoint_command(cls) -> List[str]:
@@ -189,7 +191,48 @@ def error(self, message: str) -> NoReturn:
189191
result, _ = parser.parse_known_args(arguments)
190192
return vars(result)
191193

192-
def load_snapshot(self) -> "PipelineSnapshotResponse":
194+
@property
195+
def snapshot(self) -> "PipelineSnapshotResponse":
196+
"""The snapshot configured for this entrypoint configuration.
197+
198+
Returns:
199+
The snapshot.
200+
"""
201+
if self._snapshot is None:
202+
self._snapshot = self._load_snapshot()
203+
return self._snapshot
204+
205+
@property
206+
def docker_settings(self) -> "DockerSettings":
207+
"""The Docker settings configured for this entrypoint configuration.
208+
209+
Returns:
210+
The Docker settings.
211+
"""
212+
return self.snapshot.pipeline_configuration.docker_settings
213+
214+
@property
215+
def should_download_code(self) -> bool:
216+
"""Whether code should be downloaded.
217+
218+
Returns:
219+
Whether code should be downloaded.
220+
"""
221+
if (
222+
self.snapshot.code_reference
223+
and self.docker_settings.allow_download_from_code_repository
224+
):
225+
return True
226+
227+
if (
228+
self.snapshot.code_path
229+
and self.docker_settings.allow_download_from_artifact_store
230+
):
231+
return True
232+
233+
return False
234+
235+
def _load_snapshot(self) -> "PipelineSnapshotResponse":
193236
"""Loads the snapshot.
194237
195238
Returns:
@@ -198,34 +241,19 @@ def load_snapshot(self) -> "PipelineSnapshotResponse":
198241
snapshot_id = UUID(self.entrypoint_args[SNAPSHOT_ID_OPTION])
199242
return Client().zen_store.get_snapshot(snapshot_id=snapshot_id)
200243

201-
def download_code_if_necessary(
202-
self,
203-
snapshot: "PipelineSnapshotResponse",
204-
step_name: Optional[str] = None,
205-
) -> None:
244+
def download_code_if_necessary(self) -> None:
206245
"""Downloads user code if necessary.
207246
208-
Args:
209-
snapshot: The snapshot for which to download the code.
210-
step_name: Name of the step to be run. This will be used to
211-
determine whether code download is necessary. If not given,
212-
the DockerSettings of the pipeline will be used to make that
213-
decision instead.
214-
215247
Raises:
216248
CustomFlavorImportError: If the artifact store flavor can't be
217249
imported.
218250
RuntimeError: If the current environment requires code download
219251
but the snapshot does not have a reference to any code.
220252
"""
221-
should_download_code = self._should_download_code(
222-
snapshot=snapshot, step_name=step_name
223-
)
224-
225-
if not should_download_code:
253+
if not self.should_download_code:
226254
return
227255

228-
if code_path := snapshot.code_path:
256+
if code_path := self.snapshot.code_path:
229257
# Load the artifact store not from the active stack but separately.
230258
# This is required in case the stack has custom flavor components
231259
# (other than the artifact store) for which the flavor
@@ -247,7 +275,7 @@ def download_code_if_necessary(
247275
code_utils.download_code_from_artifact_store(
248276
code_path=code_path, artifact_store=artifact_store
249277
)
250-
elif code_reference := snapshot.code_reference:
278+
elif code_reference := self.snapshot.code_reference:
251279
# TODO: This might fail if the code repository had unpushed changes
252280
# at the time the pipeline run was started.
253281
self.download_code_from_code_repository(
@@ -294,43 +322,6 @@ def download_code_from_code_repository(
294322
sys.path.insert(0, download_dir)
295323
os.chdir(download_dir)
296324

297-
def _should_download_code(
298-
self,
299-
snapshot: "PipelineSnapshotResponse",
300-
step_name: Optional[str] = None,
301-
) -> bool:
302-
"""Checks whether code should be downloaded.
303-
304-
Args:
305-
snapshot: The snapshot to check.
306-
step_name: Name of the step to be run. This will be used to
307-
determine whether code download is necessary. If not given,
308-
the DockerSettings of the pipeline will be used to make that
309-
decision instead.
310-
311-
Returns:
312-
Whether code should be downloaded.
313-
"""
314-
docker_settings = (
315-
snapshot.step_configurations[step_name].config.docker_settings
316-
if step_name
317-
else snapshot.pipeline_configuration.docker_settings
318-
)
319-
320-
if (
321-
snapshot.code_reference
322-
and docker_settings.allow_download_from_code_repository
323-
):
324-
return True
325-
326-
if (
327-
snapshot.code_path
328-
and docker_settings.allow_download_from_artifact_store
329-
):
330-
return True
331-
332-
return False
333-
334325
def _load_active_artifact_store(self) -> "BaseArtifactStore":
335326
"""Load the active artifact store.
336327

src/zenml/entrypoints/pipeline_entrypoint_configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ class PipelineEntrypointConfiguration(BaseEntrypointConfiguration):
2525

2626
def run(self) -> None:
2727
"""Prepares the environment and runs the configured pipeline."""
28-
snapshot = self.load_snapshot()
28+
snapshot = self.snapshot
2929

3030
# Activate all the integrations. This makes sure that all materializers
3131
# and stack component flavors are registered.
3232
integration_registry.activate_integrations()
3333

34-
self.download_code_if_necessary(snapshot=snapshot)
34+
self.download_code_if_necessary()
3535

3636
orchestrator = Client().active_stack.orchestrator
3737
orchestrator._prepare_run(snapshot=snapshot)

src/zenml/entrypoints/step_entrypoint_configuration.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from zenml.logger import get_logger
2828

2929
if TYPE_CHECKING:
30+
from zenml.config import DockerSettings
3031
from zenml.config.step_configurations import Step
3132
from zenml.models import PipelineSnapshotResponse
3233

@@ -149,7 +150,26 @@ def get_entrypoint_arguments(
149150
kwargs[STEP_NAME_OPTION],
150151
]
151152

152-
def load_snapshot(self) -> "PipelineSnapshotResponse":
153+
@property
154+
def docker_settings(self) -> "DockerSettings":
155+
"""The Docker settings configured for this entrypoint configuration.
156+
157+
Returns:
158+
The Docker settings.
159+
"""
160+
return self.step.config.docker_settings
161+
162+
@property
163+
def step(self) -> "Step":
164+
"""The step configured for this entrypoint configuration.
165+
166+
Returns:
167+
The step.
168+
"""
169+
step_name = self.entrypoint_args[STEP_NAME_OPTION]
170+
return self.snapshot.step_configurations[step_name]
171+
172+
def _load_snapshot(self) -> "PipelineSnapshotResponse":
153173
"""Loads the snapshot.
154174
155175
Returns:
@@ -163,7 +183,7 @@ def load_snapshot(self) -> "PipelineSnapshotResponse":
163183

164184
def run(self) -> None:
165185
"""Prepares the environment and runs the configured step."""
166-
snapshot = self.load_snapshot()
186+
snapshot = self.snapshot
167187

168188
# Activate all the integrations. This makes sure that all materializers
169189
# and stack component flavors are registered.
@@ -178,7 +198,7 @@ def run(self) -> None:
178198
os.makedirs("/app", exist_ok=True)
179199
os.chdir("/app")
180200

181-
self.download_code_if_necessary(snapshot=snapshot, step_name=step_name)
201+
self.download_code_if_necessary()
182202

183203
# If the working directory is not in the sys.path, we include it to make
184204
# sure user code gets correctly imported.
@@ -188,8 +208,7 @@ def run(self) -> None:
188208

189209
pipeline_name = snapshot.pipeline_configuration.name
190210

191-
step = snapshot.step_configurations[step_name]
192-
self._run_step(step, snapshot=snapshot)
211+
self._run_step(step=self.step, snapshot=snapshot)
193212

194213
self.post_run(
195214
pipeline_name=pipeline_name,

src/zenml/pipelines/dynamic/entrypoint_configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def get_entrypoint_arguments(
5858

5959
def run(self) -> None:
6060
"""Prepares the environment and runs the configured dynamic pipeline."""
61-
snapshot = self.load_snapshot()
61+
snapshot = self.snapshot
6262

6363
# Activate all the integrations. This makes sure that all materializers
6464
# and stack component flavors are registered.
6565
integration_registry.activate_integrations()
6666

67-
self.download_code_if_necessary(snapshot=snapshot)
67+
self.download_code_if_necessary()
6868

6969
run = None
7070
if run_id := self.entrypoint_args.get(RUN_ID_OPTION, None):

src/zenml/step_operators/step_operator_entrypoint_configuration.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# permissions and limitations under the License.
1414
"""Abstract base class for entrypoint configurations that run a single step."""
1515

16-
from typing import TYPE_CHECKING, Any, Dict, List
16+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
1717
from uuid import UUID
1818

1919
from zenml.client import Client
20+
from zenml.config.step_configurations import Step
2021
from zenml.config.step_run_info import StepRunInfo
2122
from zenml.entrypoints.step_entrypoint_configuration import (
2223
STEP_NAME_OPTION,
@@ -26,15 +27,18 @@
2627
from zenml.orchestrators.step_runner import StepRunner
2728

2829
if TYPE_CHECKING:
29-
from zenml.config.step_configurations import Step
30-
from zenml.models import PipelineSnapshotResponse
30+
from zenml.models import PipelineSnapshotResponse, StepRunResponse
3131

3232
STEP_RUN_ID_OPTION = "step_run_id"
3333

3434

3535
class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration):
3636
"""Base class for step operator entrypoint configurations."""
3737

38+
def __init__(self, *args: Any, **kwargs: Any) -> None:
39+
super().__init__(*args, **kwargs)
40+
self._step_run: Optional["StepRunResponse"] = None
41+
3842
@classmethod
3943
def get_entrypoint_options(cls) -> Dict[str, bool]:
4044
"""Gets all options required for running with this configuration.
@@ -64,6 +68,27 @@ def get_entrypoint_arguments(
6468
kwargs[STEP_RUN_ID_OPTION],
6569
]
6670

71+
@property
72+
def step_run(self) -> "StepRunResponse":
73+
"""The step run configured for this entrypoint configuration.
74+
75+
Returns:
76+
The step run.
77+
"""
78+
if self._step_run is None:
79+
step_run_id = UUID(self.entrypoint_args[STEP_RUN_ID_OPTION])
80+
self._step_run = Client().zen_store.get_run_step(step_run_id)
81+
return self._step_run
82+
83+
@property
84+
def step(self) -> "Step":
85+
"""The step configured for this entrypoint configuration.
86+
87+
Returns:
88+
The step.
89+
"""
90+
return Step(spec=self.step_run.spec, config=self.step_run.config)
91+
6792
def _run_step(
6893
self,
6994
step: "Step",
@@ -75,8 +100,7 @@ def _run_step(
75100
step: The step to run.
76101
snapshot: The snapshot configuration.
77102
"""
78-
step_run_id = UUID(self.entrypoint_args[STEP_RUN_ID_OPTION])
79-
step_run = Client().zen_store.get_run_step(step_run_id)
103+
step_run = self.step_run
80104
pipeline_run = Client().get_pipeline_run(step_run.pipeline_run_id)
81105

82106
step_run_info = StepRunInfo(
@@ -87,7 +111,7 @@ def _run_step(
87111
run_name=pipeline_run.name,
88112
pipeline_step_name=self.entrypoint_args[STEP_NAME_OPTION],
89113
run_id=pipeline_run.id,
90-
step_run_id=step_run_id,
114+
step_run_id=step_run.id,
91115
force_write_logs=lambda: None,
92116
)
93117

0 commit comments

Comments
 (0)