Skip to content

Commit f046037

Browse files
committed
fix older syntax
1 parent 516d6e3 commit f046037

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

databricks-production-qa-demo/steps/hp_tuning/hp_tuning_select_best_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def hp_tuning_select_best_model(
4747
best_metric = -1
4848
# consume artifacts attached to current model version in Model Control Plane
4949
for step_name in step_names:
50-
hp_output = model.get_data_artifact("hp_result")
50+
hp_output = model.get_model_artifact("hp_result")
5151
model_: ClassifierMixin = hp_output.load()
5252
# fetch metadata we attached earlier
53-
metric = float(hp_output.run_metadata["metric"].value)
53+
metric = float(hp_output.run_metadata["metric"])
5454
if best_model is None or best_metric < metric:
5555
best_model = model_
5656
### YOUR CODE ENDS HERE ###

databricks-production-qa-demo/steps/promotion/promote_with_metric_compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ def promote_with_metric_compare(
8989
# Promote in Model Registry
9090
latest_version_model_registry_number = latest_version.run_metadata[
9191
"model_registry_version"
92-
].value
92+
]
9393
if current_version_number is None:
9494
current_version_model_registry_number = (
9595
latest_version_model_registry_number
9696
)
9797
else:
9898
current_version_model_registry_number = (
99-
current_version.run_metadata["model_registry_version"].value
99+
current_version.run_metadata["model_registry_version"]
100100
)
101101
promote_in_model_registry(
102102
latest_version=latest_version_model_registry_number,
@@ -108,7 +108,7 @@ def promote_with_metric_compare(
108108
else:
109109
promoted_version = current_version.run_metadata[
110110
"model_registry_version"
111-
].value
111+
]
112112

113113
logger.info(
114114
f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry"

databricks-production-qa-demo/steps/training/model_evaluator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import mlflow
33
import pandas as pd
44
from sklearn.base import ClassifierMixin
5-
from zenml import step, get_step_context, log_model_metadata
5+
from zenml import step, get_step_context, log_metadata
66
from zenml.client import Client
77
from zenml.logger import get_logger
88
from PIL import Image, ImageDraw, ImageFont, ImageFilter
@@ -34,13 +34,14 @@ def model_evaluator(
3434

3535
step_context = get_step_context()
3636

37-
log_model_metadata(
37+
log_metadata(
3838
metadata={
3939
"evaluation_metrics": {
4040
"train_accuracy": trn_acc,
4141
"test_accuracy": tst_acc
4242
}
4343
},
44+
infer_model=True
4445
)
4546

4647
# Fetch previous versions (same as before)
@@ -49,7 +50,7 @@ def model_evaluator(
4950
for version in client.get_model(step_context.model.name).versions:
5051
version_obj = client.get_model_version(step_context.model.name, version.version)
5152
if "evaluation_metrics" in version_obj.run_metadata:
52-
test_accuracy = version_obj.run_metadata["evaluation_metrics"].value.get("test_accuracy")
53+
test_accuracy = version_obj.run_metadata["evaluation_metrics"].get("test_accuracy")
5354
if test_accuracy is not None:
5455
previous_versions.append((f"v{version.version}", float(test_accuracy)))
5556

0 commit comments

Comments
 (0)