Skip to content

Commit e4bee5a

Browse files
committed
Refactor parameter model handling and clean up code
This commit simplifies the parameter model construction in the deployment service by directly using the `build_params_model_from_snapshot` function. It also removes unused functions and redundant comments, enhancing code clarity and maintainability. Additionally, the error handling in the parameter model builder has been improved to log errors more effectively. Fixes #1234
1 parent 7b3e7c7 commit e4bee5a

File tree

7 files changed

+16
-83
lines changed

7 files changed

+16
-83
lines changed

src/zenml/config/compiler.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,14 +656,10 @@ def _compute_pipeline_spec(
656656
logger.warning("Failed to compute pipeline output schema: %s", e)
657657
output_schema = None
658658

659-
try:
660-
parameters_model = pipeline.get_parameters_model()
661-
if parameters_model:
662-
input_schema = parameters_model.model_json_schema()
663-
else:
664-
input_schema = None
665-
except Exception as e:
666-
logger.warning("Failed to compute pipeline input schema: %s", e)
659+
parameters_model = pipeline.get_parameters_model()
660+
if parameters_model:
661+
input_schema = parameters_model.model_json_schema()
662+
else:
667663
input_schema = None
668664

669665
return PipelineSpec(

src/zenml/deployers/server/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
7171

7272
try:
7373
global _service
74-
# Defer UUID parsing to the service itself to simplify testing
7574
_service = PipelineDeploymentService(snapshot_id)
7675
_service.initialize()
7776
# params model is available.

src/zenml/deployers/server/parameters.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
14-
"""Shared utilities to construct and validate pipeline parameter models.
15-
16-
This module centralizes the logic to:
17-
- Build a Pydantic model for pipeline parameters from a snapshot
18-
- Validate and normalize request parameters using that model
19-
20-
It is intentionally independent of FastAPI or deployment internals so that
21-
other entry points (e.g., CLI) can reuse the same behavior.
22-
"""
14+
"""Parameters model builder for deployments."""
2315

2416
from typing import Optional, Type
2517

@@ -34,9 +26,8 @@
3426

3527

3628
def build_params_model_from_snapshot(
37-
snapshot: PipelineSnapshotResponse,
3829
*,
39-
strict: bool = True,
30+
snapshot: PipelineSnapshotResponse,
4031
) -> Optional[Type[BaseModel]]:
4132
"""Construct a Pydantic model representing pipeline parameters.
4233
@@ -61,8 +52,7 @@ def build_params_model_from_snapshot(
6152
f"Snapshot `{snapshot.id}` is missing pipeline_spec.source; "
6253
"cannot build parameter model."
6354
)
64-
if strict:
65-
raise RuntimeError(msg)
55+
logger.error(msg)
6656
return None
6757

6858
try:
@@ -71,19 +61,18 @@ def build_params_model_from_snapshot(
7161
)
7262
except Exception as e:
7363
logger.debug(f"Failed to load pipeline class from snapshot: {e}")
74-
if strict:
75-
raise
76-
return None
64+
logger.error(f"Failed to load pipeline class from snapshot: {e}")
65+
raise RuntimeError(f"Failed to load pipeline class from snapshot: {e}")
7766

7867
model = pipeline_class.get_parameters_model()
7968
if not model:
8069
message = (
8170
f"Failed to construct parameters model from pipeline "
8271
f"`{snapshot.pipeline_configuration.name}`."
8372
)
84-
if strict:
85-
raise RuntimeError(message)
86-
else:
73+
logger.error(message)
74+
raise RuntimeError(message)
75+
else:
8776
logger.debug(message)
8877

8978
return model

src/zenml/deployers/server/runtime.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,4 @@ def get_in_memory_data(uri: str) -> Any:
169169
if is_active():
170170
state = _get_context()
171171
return state.in_memory_data.get(uri)
172-
return None
173-
174-
175-
def has_in_memory_data(uri: str) -> bool:
176-
"""Check if data exists in memory for the given URI.
177-
178-
Args:
179-
uri: The artifact URI to check.
180-
181-
Returns:
182-
True if data exists in memory for the URI.
183-
"""
184-
if is_active():
185-
state = _get_context()
186-
return uri in state.in_memory_data
187-
return False
172+
return None

src/zenml/deployers/server/service.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ServiceInfo,
3535
SnapshotInfo,
3636
)
37+
from zenml.deployers.server.parameters import build_params_model_from_snapshot
3738
from zenml.enums import StackComponentType
3839
from zenml.hooks.hook_validators import load_and_run_hook
3940
from zenml.integrations.registry import integration_registry
@@ -135,7 +136,7 @@ def initialize(self) -> None:
135136
integration_registry.activate_integrations()
136137

137138
# Build parameter model
138-
self._params_model = self._build_params_model()
139+
self._params_model = build_params_model_from_snapshot(self.snapshot, strict=True)
139140

140141
# Initialize orchestrator
141142
self._orchestrator = SharedLocalOrchestrator(
@@ -324,14 +325,6 @@ def _map_outputs(
324325

325326
return filtered_outputs
326327

327-
def _prepare_execute_with_orchestrator(
328-
self,
329-
) -> PipelineRunResponse:
330-
# Create a placeholder run and execute with a known run id
331-
return run_utils.create_placeholder_run(
332-
snapshot=self.snapshot, logs=None
333-
)
334-
335328
def _execute_with_orchestrator(
336329
self,
337330
placeholder_run: PipelineRunResponse,
@@ -411,25 +404,6 @@ def _execute_with_orchestrator(
411404
# Store captured outputs for the caller to use
412405
return captured_outputs
413406

414-
def _build_params_model(self) -> Any:
415-
"""Build the pipeline parameters model from the deployment.
416-
417-
Returns:
418-
A parameters model derived from the deployment configuration.
419-
420-
Raises:
421-
Exception: If the model cannot be constructed.
422-
"""
423-
try:
424-
from zenml.deployers.server.parameters import (
425-
build_params_model_from_snapshot,
426-
)
427-
428-
return build_params_model_from_snapshot(self.snapshot, strict=True)
429-
except Exception as e:
430-
logger.error(f"Failed to construct parameter model: {e}")
431-
raise
432-
433407
def _execute_init_hook(self) -> None:
434408
"""Execute init hook if present.
435409

src/zenml/orchestrators/step_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,6 @@ def run(
142142
Raises:
143143
BaseException: A general exception if the step fails.
144144
"""
145-
# Store step_run_info for effective config access
146-
self._step_run_info = step_run_info
147145
if handle_bool_env_var(ENV_ZENML_DISABLE_STEP_LOGS_STORAGE, False):
148146
step_logging_enabled = False
149147
else:
@@ -415,7 +413,7 @@ def _load_step(self) -> "BaseStep":
415413

416414
step_instance = BaseStep.load_from_source(self._step.spec.source)
417415
step_instance = copy.deepcopy(step_instance)
418-
step_instance._configuration = self._step_run_info.config
416+
step_instance._configuration = self._step.config
419417
return step_instance
420418

421419
def _load_output_materializers(

tests/unit/deployers/serving/test_runtime.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,13 @@ def test_in_memory_data_storage(self):
160160
)
161161
assert runtime.get_in_memory_data("memory://missing") is None
162162

163-
# Check existence
164-
assert runtime.has_in_memory_data("memory://artifact/1")
165-
assert runtime.has_in_memory_data("memory://artifact/2")
166-
assert not runtime.has_in_memory_data("memory://missing")
167-
168163
def test_in_memory_data_inactive_context(self):
169164
"""Test in-memory data operations when context is inactive."""
170165
# Don't start context
171166
runtime.put_in_memory_data("memory://artifact/1", {"data": "value"})
172167

173168
# Should not store anything
174169
assert runtime.get_in_memory_data("memory://artifact/1") is None
175-
assert not runtime.has_in_memory_data("memory://artifact/1")
176170

177171
def test_context_isolation(self):
178172
"""Test that multiple contexts don't interfere with each other."""
@@ -274,7 +268,6 @@ def test_context_reset_clears_all_data(self):
274268
# Verify data is stored
275269
assert runtime.is_active()
276270
assert runtime.get_outputs() != {}
277-
assert runtime.has_in_memory_data("memory://artifact/1")
278271
assert runtime.should_use_in_memory_mode() is True
279272

280273
# Stop context (triggers reset)
@@ -292,5 +285,4 @@ def test_context_reset_clears_all_data(self):
292285

293286
assert runtime.get_outputs() == {}
294287
assert runtime.get_in_memory_data("memory://artifact/1") is None
295-
assert not runtime.has_in_memory_data("memory://artifact/1")
296288
assert runtime.should_use_in_memory_mode() is False

0 commit comments

Comments
 (0)