Skip to content

Commit 0c8cca2

Browse files
committed
format
1 parent 69f914a commit 0c8cca2

File tree

29 files changed: +111 additions, -103 deletions

airflow-cloud-composer-etl-feature-train/steps/training/model_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@ def train_xgboost_model(
3232
dataset: Union[BigQueryDataset, CSVDataset],
3333
) -> Tuple[
3434
Annotated[
35-
xgb.Booster, ArtifactConfig(name="xgb_model", artifact_type=ArtifactType.MODEL)
35+
xgb.Booster,
36+
ArtifactConfig(name="xgb_model", artifact_type=ArtifactType.MODEL),
3637
],
3738
Annotated[Dict[str, float], "metrics"],
3839
]:

databricks-production-qa-demo/pipelines/batch_inference.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,12 @@
2323
notify_on_failure,
2424
notify_on_success,
2525
)
26+
from steps.explainability import explain_model
2627
from zenml import get_pipeline_context, pipeline
2728
from zenml.integrations.evidently.metrics import EvidentlyMetricConfig
2829
from zenml.integrations.evidently.steps import evidently_report_step
2930
from zenml.logger import get_logger
3031

31-
from steps.explainability import explain_model
32-
3332
logger = get_logger(__name__)
3433

3534

databricks-production-qa-demo/steps/deployment/deployment_deploy.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -31,14 +31,10 @@
3131

3232

3333
@step(enable_cache=False)
34-
def deployment_deploy() -> (
35-
Annotated[
36-
Optional[DatabricksDeploymentService],
37-
ArtifactConfig(
38-
name="databricks_deployment", is_deployment_artifact=True
39-
),
40-
]
41-
):
34+
def deployment_deploy() -> Annotated[
35+
Optional[DatabricksDeploymentService],
36+
ArtifactConfig(name="databricks_deployment", is_deployment_artifact=True),
37+
]:
4238
"""Predictions step.
4339
4440
This is an example of a predictions step that takes the data in and returns
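(File path not captured in this extract; presumably databricks-production-qa-demo/steps/explainability/__init__.py, inferred from the relative import shown below — confirm against the commit.)

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change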
@@ -1 +1 @@
1-
from .shap_explainer import explain_model
1+
from .shap_explainer import explain_model
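(File path not captured in this extract; presumably databricks-production-qa-demo/steps/explainability/shap_explainer.py, inferred from the `explain_model` definition and the `.shap_visualization` relative import shown below — confirm against the commit.)

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change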
@@ -1,20 +1,25 @@
1+
from typing import Annotated
2+
3+
import pandas as pd
4+
import shap
15
from sklearn.base import ClassifierMixin
26
from zenml import get_step_context, log_artifact_metadata
3-
import shap
4-
import pandas as pd
5-
from typing import Annotated
67
from zenml.steps import step
8+
79
from .shap_visualization import SHAPVisualization
810

11+
912
@step
1013
def explain_model(
11-
X_train: pd.DataFrame
14+
X_train: pd.DataFrame,
1215
) -> Annotated[SHAPVisualization, "shap_visualization"]:
1316
"""Generate SHAP values for model explainability and create a visualization."""
1417
model = get_step_context().model
1518
model_artifact: ClassifierMixin = model.load_artifact("model")
16-
17-
explainer = shap.KernelExplainer(model_artifact.predict_proba, shap.sample(X_train, 100))
19+
20+
explainer = shap.KernelExplainer(
21+
model_artifact.predict_proba, shap.sample(X_train, 100)
22+
)
1823
shap_values = explainer.shap_values(X_train.iloc[:100])
1924

2025
log_artifact_metadata(
@@ -25,7 +30,7 @@ def explain_model(
2530
"n_classes": len(shap_values),
2631
"n_features": shap_values[0].shape[1],
2732
}
28-
}
33+
},
2934
)
3035

31-
return SHAPVisualization(shap_values, X_train.columns)
36+
return SHAPVisualization(shap_values, X_train.columns)

databricks-production-qa-demo/steps/explainability/shap_visualization.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
1-
import os
21
import io
2+
import os
33
from typing import Dict
44

5-
import shap
65
import matplotlib.pyplot as plt
7-
6+
import shap
87
from zenml.enums import ArtifactType, VisualizationType
98
from zenml.io import fileio
109
from zenml.materializers.base_materializer import BaseMaterializer
1110

11+
1212
# Custom class to hold SHAP visualization data
1313
class SHAPVisualization:
1414
def __init__(self, shap_values, feature_names):
@@ -22,18 +22,23 @@ class SHAPVisualizationMaterializer(BaseMaterializer):
2222
ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA_ANALYSIS
2323

2424
def save_visualizations(
25-
self, data: SHAPVisualization
25+
self, data: SHAPVisualization
2626
) -> Dict[str, VisualizationType]:
2727
plt.figure(figsize=(10, 6))
28-
shap.summary_plot(data.shap_values, feature_names=data.feature_names, plot_type="bar", show=False)
28+
shap.summary_plot(
29+
data.shap_values,
30+
feature_names=data.feature_names,
31+
plot_type="bar",
32+
show=False,
33+
)
2934
plt.title("SHAP Feature Importance")
3035

3136
buf = io.BytesIO()
32-
plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
37+
plt.savefig(buf, format="png", dpi=150, bbox_inches="tight")
3338
buf.seek(0)
3439

3540
visualization_path = os.path.join(self.uri, "shap_summary_plot.png")
36-
with fileio.open(visualization_path, 'wb') as f:
41+
with fileio.open(visualization_path, "wb") as f:
3742
f.write(buf.getvalue())
3843

3944
plt.close()

databricks-production-qa-demo/steps/training/model_evaluator.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from PIL import Image, ImageDraw, ImageFont
44
from sklearn.base import ClassifierMixin
55
from typing_extensions import Annotated
6-
from zenml import step, get_step_context, log_metadata
6+
from zenml import get_step_context, log_metadata, step
77
from zenml.client import Client
88
from zenml.logger import get_logger
99

@@ -31,8 +31,8 @@ def model_evaluator(
3131
dataset_tst.drop(columns=[target]), dataset_tst[target]
3232
)
3333

34-
logger.info(f"Train accuracy={trn_acc*100:.2f}%")
35-
logger.info(f"Test accuracy={tst_acc*100:.2f}%")
34+
logger.info(f"Train accuracy={trn_acc * 100:.2f}%")
35+
logger.info(f"Test accuracy={tst_acc * 100:.2f}%")
3636
mlflow.log_metric("testing_accuracy_score", tst_acc)
3737

3838
step_context = get_step_context()
@@ -44,7 +44,7 @@ def model_evaluator(
4444
"test_accuracy": tst_acc,
4545
}
4646
},
47-
infer_model=True
47+
infer_model=True,
4848
)
4949

5050
# Fetch previous versions (same as before)
@@ -55,7 +55,9 @@ def model_evaluator(
5555
step_context.model.name, version.version
5656
)
5757
if "evaluation_metrics" in version_obj.run_metadata:
58-
test_accuracy = version_obj.run_metadata["evaluation_metrics"].get("test_accuracy")
58+
test_accuracy = version_obj.run_metadata["evaluation_metrics"].get(
59+
"test_accuracy"
60+
)
5961
if test_accuracy is not None:
6062
previous_versions.append(
6163
(f"v{version.version}", float(test_accuracy))

databricks-production-qa-demo/steps/training/model_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,8 @@ def model_trainer(
5050
target: str,
5151
name: str,
5252
) -> Annotated[
53-
ClassifierMixin, ArtifactConfig(name="model", artifact_type=ArtifactType.MODEL)
53+
ClassifierMixin,
54+
ArtifactConfig(name="model", artifact_type=ArtifactType.MODEL),
5455
]:
5556
"""Configure and train a model on the training dataset.
5657

end-to-end-computer-vision/steps/export_label_studio.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,16 +16,12 @@
1616
#
1717
from typing import Annotated, List, Tuple
1818

19-
from zenml import log_metadata, step
20-
from zenml.client import Client
21-
from zenml.logger import get_logger
22-
2319
from materializers.label_studio_export_materializer import (
2420
LabelStudioAnnotationExport,
2521
LabelStudioAnnotationMaterializer,
2622
)
2723
from utils.constants import LABELED_DATASET_NAME
28-
from zenml import log_artifact_metadata, step
24+
from zenml import log_metadata, step
2925
from zenml.client import Client
3026
from zenml.logger import get_logger
3127

end-to-end-computer-vision/steps/fiftyone_inference.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,12 @@
1818
from typing import Annotated
1919

2020
import fiftyone as fo
21-
from zenml import log_metadata, step
22-
from zenml.client import Client
23-
from zenml.io import fileio
24-
from zenml.logger import get_logger
25-
2621
from utils.constants import (
2722
DATASET_NAME,
2823
PREDICTIONS_DATASET_ARTIFACT_NAME,
2924
TRAINED_MODEL_NAME,
3025
)
31-
from zenml import log_artifact_metadata, step
26+
from zenml import log_metadata, step
3227
from zenml.client import Client
3328
from zenml.io import fileio
3429
from zenml.logger import get_logger

0 commit comments

Comments
 (0)