Skip to content

Commit e422f2d

Browse files
committed
Refactor model deployment logic and update README with new configurations
1 parent 64bde44 commit e422f2d

File tree

8 files changed

+142
-91
lines changed

8 files changed

+142
-91
lines changed

vertex-registry-and-deployer/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ Once you have set up your Google Cloud credentials, we can create a stack and ru
2626

2727
```bash
2828
# Register the artifact store
29-
zenml artifact-store register gs_store -f gcp --path=gs://bucket-name
29+
zenml artifact-store register gs_store -f gcp --path=gs://zenml-vertex-test
3030
zenml artifact-store connect gs_store --connector gcp
3131

3232
# Register the model registry
33-
zenml model-registry register vertex_registry --flavor=vertex --location=us-central1
33+
zenml model-registry register vertex_registry --flavor=vertex --location=europe-west1
3434
zenml model-registry connect vertex_registry --connector gcp
3535

3636
# Register the model deployer
37-
zenml model-deployer register vertex_deployer --flavor=vertex --location=us-central1
38-
zenml model-deployer connect vertex_deployer --connector vertex_deployer_connector
37+
zenml model-deployer register vertex_deployer --flavor=vertex --location=europe-west1
38+
zenml model-deployer connect vertex_deployer --connector gcp
3939

4040
# Register the stack
4141
zenml stack register vertex_stack --orchestrator default --artifact-store gs_store --model-registry vertex_registry --model-deployer vertex_deployer

vertex-registry-and-deployer/pipelines/inference.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
from zenml import get_pipeline_context, pipeline
17+
from zenml import pipeline
1818
from zenml.logger import get_logger
1919

2020
logger = get_logger(__name__)
2121

2222

2323
@pipeline
2424
def inference(random_state: int, target: str):
25-
"""
26-
Model inference pipeline.
25+
"""Model inference pipeline.
2726
2827
This is a pipeline that loads the inference data, processes it with
2928
the same preprocessing pipeline used in training, and runs inference
@@ -34,12 +33,12 @@ def inference(random_state: int, target: str):
3433
target: Name of target column in dataset.
3534
"""
3635
# Get the production model artifact
37-
model = get_pipeline_context().model.get_artifact("sklearn_classifier")
36+
# model_artifact = get_pipeline_context().model.get_artifact("sklearn_classifier")
3837

3938
# Get the preprocess pipeline artifact associated with this version
40-
preprocess_pipeline = get_pipeline_context().model.get_artifact(
41-
"preprocess_pipeline"
42-
)
39+
# preprocess_pipeline = get_pipeline_context().model.get_artifact(
40+
# "preprocess_pipeline"
41+
# )
4342

4443
# Link all the steps together by calling them and passing the output
4544
# of one step as the input of the next step.

vertex-registry-and-deployer/pipelines/training.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
#
1717

1818
from typing import Optional
19-
from uuid import UUID
2019

2120
from steps import model_deployer, model_promoter, model_register, model_trainer
21+
2222
from zenml import pipeline
23-
from zenml.client import Client
2423
from zenml.logger import get_logger
2524

2625
logger = get_logger(__name__)
@@ -47,9 +46,9 @@ def training(
4746
# Link all the steps together by calling them and passing the output
4847
# of one step as the input of the next step.
4948

50-
model, accuracy = model_trainer(target=target)
49+
_, accuracy = model_trainer(target=target)
5150
is_promoted = model_promoter(accuracy=accuracy)
52-
if is_promoted:
53-
model_registry_uri = model_register()
54-
model_deployer(model_registry_uri=model_registry_uri)
55-
51+
model_registry_uri = model_register(is_promoted)
52+
model_deployer(
53+
model_registry_uri=model_registry_uri, is_promoted=is_promoted
54+
)

vertex-registry-and-deployer/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
#
1717

1818
import os
19-
from typing import Optional
2019

2120
import click
2221
import yaml
2322
from pipelines import (
2423
inference,
2524
training,
2625
)
26+
2727
from zenml.client import Client
2828
from zenml.logger import get_logger
2929

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
1-
# Apache Software License 2.0
2-
#
3-
# Copyright (c) ZenML GmbH 2023. All rights reserved.
4-
#
5-
# Licensed under the Apache License, Version 2.0 (the "License");
6-
# you may not use this file except in compliance with the License.
7-
# You may obtain a copy of the License at
8-
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
10-
#
11-
# Unless required by applicable law or agreed to in writing, software
12-
# distributed under the License is distributed on an "AS IS" BASIS,
13-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
# See the License for the specific language governing permissions and
15-
# limitations under the License.
16-
#
17-
18-
from typing import Any
19-
201
from typing_extensions import Annotated
2+
213
from zenml import ArtifactConfig, get_step_context, step
224
from zenml.client import Client
235
from zenml.integrations.gcp.services.vertex_deployment import (
24-
VertexAIDeploymentConfig,
6+
VertexDeploymentConfig,
257
VertexDeploymentService,
268
)
279
from zenml.logger import get_logger
@@ -32,33 +14,46 @@
3214
@step(enable_cache=False)
3315
def model_deployer(
3416
model_registry_uri: str,
17+
is_promoted: bool = False,
3518
) -> Annotated[
36-
VertexDeploymentService, ArtifactConfig(name="vertex_deployment", is_deployment_artifact=True)
19+
VertexDeploymentService,
20+
ArtifactConfig(name="vertex_deployment", is_deployment_artifact=True),
3721
]:
3822
"""Model deployer step.
39-
23+
4024
Args:
4125
model_registry_uri: URI of the model in the model registry.
42-
26+
is_promoted: Whether the model was promoted to production.
27+
4328
Returns:
4429
The deployed model service.
4530
"""
46-
zenml_client = Client()
47-
current_model = get_step_context().model
48-
model_deployer = zenml_client.active_stack.model_deployer
49-
vertex_deployment_config = VertexAIDeploymentConfig(
50-
location="europe-west1",
51-
name="zenml-vertex-quickstart",
52-
model_name=current_model.name,
53-
description="An example of deploying a model using the MLflow Model Deployer",
54-
model_id=model_registry_uri,
55-
)
56-
service = model_deployer.deploy_model(
57-
config=vertex_deployment_config,
58-
service_type=VertexDeploymentService.SERVICE_TYPE,
59-
)
31+
if not is_promoted:
32+
logger.info("Model not promoted, skipping deployment")
33+
return None
34+
else:
35+
zenml_client = Client()
36+
current_model = get_step_context().model
37+
model_deployer = zenml_client.active_stack.model_deployer
38+
39+
# Create deployment configuration with proper model name and version
40+
vertex_deployment_config = VertexDeploymentConfig(
41+
location="europe-west1",
42+
name=current_model.name,
43+
display_name="zenml-vertex-quickstart",
44+
model_name=model_registry_uri, # This is the full resource name from registration
45+
model_version=current_model.version, # Specify the version explicitly
46+
description="An example of deploying a model using the Vertex AI Model Deployer",
47+
sync=True, # Wait for deployment to complete
48+
traffic_percentage=100, # Route all traffic to this version
49+
)
50+
51+
service = model_deployer.deploy_model(
52+
config=vertex_deployment_config,
53+
service_type=VertexDeploymentService.SERVICE_TYPE,
54+
)
6055

61-
logger.info(
62-
f"The deployed service info: {model_deployer.get_model_server_info(service)}"
63-
)
64-
return service
56+
logger.info(
57+
f"The deployed service info: {model_deployer.get_model_server_info(service)}"
58+
)
59+
return service

vertex-registry-and-deployer/steps/model_promoter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def model_promoter(accuracy: float, stage: str = "production") -> bool:
5858
try:
5959
stage_model = client.get_model_version(current_model.name, stage)
6060
# We compare their metrics
61-
prod_accuracy = (
62-
stage_model.get_artifact("sklearn_classifier").run_metadata["test_accuracy"].value
63-
)
61+
prod_accuracy = stage_model.get_artifact(
62+
"sklearn_classifier"
63+
).run_metadata["test_accuracy"]
6464
if accuracy > float(prod_accuracy):
6565
# If current model has better metrics, we promote it
6666
is_promoted = True
Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,77 @@
1-
# model_register.py
2-
31
from typing_extensions import Annotated
2+
43
from zenml import ArtifactConfig, get_step_context, step
54
from zenml.client import Client
5+
from zenml.integrations.gcp.flavors.vertex_base_config import (
6+
VertexAIContainerSpec,
7+
VertexAIModelConfig,
8+
VertexAIResourceSpec,
9+
)
610
from zenml.logger import get_logger
11+
from zenml.model_registries.base_model_registry import (
12+
ModelRegistryModelMetadata,
13+
)
714

815
logger = get_logger(__name__)
916

17+
1018
@step(enable_cache=False)
11-
def model_register() -> Annotated[str, ArtifactConfig(name="model_registry_uri")]:
12-
"""Model registration step."""
13-
# Get the current model from the context
14-
current_model = get_step_context().model
15-
16-
client = Client()
17-
model_registry = client.active_stack.model_registry
18-
model_version = model_registry.register_model_version(
19-
name=current_model.name,
20-
version=str(current_model.version),
21-
model_source_uri=current_model.get_model_artifact("sklearn_classifier").uri,
22-
description="ZenML model registered after promotion",
23-
)
24-
logger.info(
25-
f"Model version {model_version.version} registered in Model Registry"
26-
)
27-
28-
return model_version.model_source_uri
19+
def model_register(
20+
is_promoted: bool = False,
21+
) -> Annotated[str, ArtifactConfig(name="model_registry_uri")]:
22+
"""Model registration step.
23+
24+
Registers a model version in the Vertex AI Model Registry with extended configuration
25+
and returns the registered model's full resource name (an empty string when the model was not promoted). This configuration embeds details such as container,
26+
resource, explanation settings etc. so that the deployment can reuse these pre-configured
27+
settings from the registry.
28+
"""
29+
if is_promoted:
30+
# Get the current model from the step context
31+
current_model = get_step_context().model
32+
33+
client = Client()
34+
model_registry = client.active_stack.model_registry
35+
# Create an extended model configuration using the base settings for Vertex AI
36+
model_config = VertexAIModelConfig(
37+
location="europe-west1",
38+
container=VertexAIContainerSpec(
39+
image_uri="europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
40+
predict_route="predict",
41+
health_route="health",
42+
ports=[8080],
43+
),
44+
resources=VertexAIResourceSpec(
45+
machine_type="n1-standard-4",
46+
min_replica_count=1,
47+
max_replica_count=1,
48+
),
49+
labels={"env": "production"},
50+
description="Extended model configuration for Vertex AI",
51+
)
52+
53+
# Register the model version, passing pipeline/run/step provenance as metadata and the extended configuration as config
54+
model_version = model_registry.register_model_version(
55+
name=current_model.name,
56+
version=str(current_model.version),
57+
model_source_uri=current_model.get_model_artifact(
58+
"sklearn_classifier"
59+
).uri,
60+
description="ZenML model version registered with extended configuration",
61+
metadata=ModelRegistryModelMetadata(
62+
zenml_pipeline_name=get_step_context().pipeline.name,
63+
zenml_pipeline_run_uuid=str(
64+
get_step_context().pipeline_run.id
65+
),
66+
zenml_step_name=get_step_context().step_run.name,
67+
),
68+
config=model_config,
69+
)
70+
logger.info(
71+
f"Model version {model_version.version} registered in Model Registry"
72+
)
73+
74+
# Return the full resource name of the registered model
75+
return model_version.registered_model.name
76+
else:
77+
return ""

vertex-registry-and-deployer/steps/model_trainer.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Apache Software License 2.0
32
#
43
# Copyright (c) ZenML GmbH 2024. All rights reserved.
@@ -16,7 +15,7 @@
1615
# limitations under the License.
1716
#
1817

19-
from typing import Optional, Tuple, Union
18+
from typing import Tuple, Union
2019

2120
import pandas as pd
2221
from sklearn.base import ClassifierMixin
@@ -26,12 +25,16 @@
2625
from sklearn.pipeline import Pipeline
2726
from sklearn.preprocessing import MinMaxScaler
2827
from typing_extensions import Annotated
29-
from zenml import ArtifactConfig, log_artifact_metadata, step
28+
29+
from zenml import ArtifactConfig, step
30+
from zenml.enums import ArtifactType
3031
from zenml.logger import get_logger
32+
from zenml.utils.metadata_utils import log_metadata
3133

3234
logger = get_logger(__name__)
3335

34-
@step
36+
37+
@step(enable_cache=False)
3538
def model_trainer(
3639
random_state: int = 42,
3740
test_size: float = 0.2,
@@ -41,7 +44,12 @@ def model_trainer(
4144
min_train_accuracy: float = 0.3,
4245
min_test_accuracy: float = 0.3,
4346
) -> Tuple[
44-
Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier", is_model_artifact=True)],
47+
Annotated[
48+
ClassifierMixin,
49+
ArtifactConfig(
50+
name="sklearn_classifier", artifact_type=ArtifactType.MODEL
51+
),
52+
],
4553
Annotated[float, ArtifactConfig(name="accuracy")],
4654
]:
4755
# Load the dataset
@@ -99,21 +107,22 @@ def model_trainer(
99107
for message in messages:
100108
logger.warning(message)
101109

102-
log_artifact_metadata(
110+
log_metadata(
103111
metadata={
104112
"train_accuracy": float(trn_acc),
105113
"test_accuracy": float(tst_acc),
106114
},
107115
artifact_name="sklearn_classifier",
116+
infer_artifact=True,
108117
)
109-
110118
return model, tst_acc
111119

120+
112121
class NADropper:
113122
"""Support class to drop NA values in sklearn Pipeline."""
114123

115124
def fit(self, *args, **kwargs): # noqa: D102
116125
return self
117126

118127
def transform(self, X: Union[pd.DataFrame, pd.Series]): # noqa: D102
119-
return X.dropna()
128+
return X.dropna()

0 commit comments

Comments
 (0)