@@ -21,6 +21,7 @@
 from materializers import BigQueryDataset, CSVDataset
 from typing_extensions import Annotated
 from zenml import ArtifactConfig, step
+from zenml.enums import ArtifactType
 from zenml.logger import get_logger

 logger = get_logger(__name__)

@@ -31,7 +32,7 @@ def train_xgboost_model(
     dataset: Union[BigQueryDataset, CSVDataset],
 ) -> Tuple[
     Annotated[
-        xgb.Booster, ArtifactConfig(name="xgb_model", is_model_artifact=True)
+        xgb.Booster, ArtifactConfig(name="xgb_model", artifact_type=ArtifactType.MODEL)
     ],
     Annotated[Dict[str, float], "metrics"],
 ]:
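
Both files in this PR make the same migration: the deprecated `is_model_artifact=True` flag becomes the explicit `artifact_type=ArtifactType.MODEL` enum from `zenml.enums`. A minimal, self-contained sketch of the new annotation pattern (the toy step body is illustrative, not code from this repo):

```python
from typing import Dict, Tuple

import numpy as np
import xgboost as xgb
from typing_extensions import Annotated
from zenml import ArtifactConfig, step
from zenml.enums import ArtifactType


@step
def train_toy_model() -> Tuple[
    # artifact_type=ArtifactType.MODEL replaces the old is_model_artifact=True
    Annotated[
        xgb.Booster,
        ArtifactConfig(name="xgb_model", artifact_type=ArtifactType.MODEL),
    ],
    Annotated[Dict[str, float], "metrics"],
]:
    dtrain = xgb.DMatrix(np.array([[0.0], [1.0]]), label=np.array([0, 1]))
    booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)
    return booster, {"train_rounds": 2.0}
```
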
3 changes: 2 additions & 1 deletion customer-satisfaction/steps/train_model.py
@@ -6,6 +6,7 @@
 from model.model_dev import ModelTrainer
 from sklearn.base import RegressorMixin
 from zenml import ArtifactConfig, step
+from zenml.enums import ArtifactType
 from zenml.client import Client

 experiment_tracker = Client().active_stack.experiment_tracker

@@ -21,7 +22,7 @@ def train_model(
     do_fine_tuning: bool = True,
 ) -> Annotated[
     RegressorMixin,
-    ArtifactConfig(name="sklearn_regressor", is_model_artifact=True),
+    ArtifactConfig(name="sklearn_regressor", artifact_type=ArtifactType.MODEL),
 ]:
     """
     Args:
156 changes: 0 additions & 156 deletions databricks-demo/README.md

This file was deleted.

File renamed without changes.
File renamed without changes.
154 changes: 154 additions & 0 deletions databricks-production-qa-demo/README.md
@@ -0,0 +1,154 @@
# Databricks + ZenML: End-to-End Explainable ML Project

Welcome to this end-to-end demo project showcasing how to train, deploy, and run batch inference on a machine learning model using ZenML in a Databricks environment. It demonstrates how ZenML simplifies building reproducible, production-grade ML pipelines with minimal fuss.

## Overview

This project uses an example classification dataset (Breast Cancer) and provides three major pipelines:

1. Training Pipeline
2. Deployment Pipeline
3. Batch Inference Pipeline (with SHAP-based model explainability)

The pipelines are orchestrated via ZenML. Additionally, this setup uses:
- Databricks as the orchestrator
- MLflow for experiment tracking and model registry
- Evidently for data drift detection
- SHAP for model explainability during inference
- Slack notifications (configurable through ZenML's alerter stack components)
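
A ZenML pipeline is an ordinary Python function that wires steps together, and the active orchestrator (here, Databricks) decides where each step actually runs. A minimal sketch of what such a pipeline looks like (the step logic is illustrative, not this project's actual `pipelines/training.py`):

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from zenml import pipeline, step


@step
def load_data() -> pd.DataFrame:
    # Same example dataset the project uses
    return load_breast_cancer(as_frame=True).frame


@step
def train(df: pd.DataFrame) -> RandomForestClassifier:
    X, y = df.drop(columns=["target"]), df["target"]
    return RandomForestClassifier().fit(X, y)


@pipeline
def training_pipeline():
    train(load_data())


if __name__ == "__main__":
    training_pipeline()  # runs on whatever stack is currently active
```
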

## Why ZenML?

ZenML is a lightweight MLOps framework for reproducible pipelines. With ZenML, you get:

- A consistent, standardized way to develop, version, and share pipelines.
- Easy integration with various cloud providers, experiment trackers, model registries, and more.
- Reproducibility and better collaboration: your pipelines and associated artifacts are automatically tracked and versioned.
- A simple command-line interface for running your pipelines on different stack components (like local or Databricks orchestrators).
- Built-in best practices for production ML, including quality gates for data drift and model performance thresholds.

## Project Structure

Here's an outline of the repository:

```
.
├── configs # Pipeline configuration files
│ ├── deployer_config.yaml # Deployment pipeline config
│ ├── inference_config.yaml # Batch inference pipeline config
│ └── train_config.yaml # Training pipeline config
├── pipelines # ZenML pipeline definitions
│ ├── batch_inference.py # Orchestrates batch inference
│ ├── deployment.py # Deploys a model service
│ └── training.py # Trains and promotes model
├── steps # ZenML steps logic
│ ├── alerts # Alert/notification logic
│ ├── data_quality # Data drift and quality checks
│ ├── deployment # Deployment step
│ ├── etl # ETL steps (data loading, preprocessing, splitting)
│ ├── explainability # SHAP-based model explanations
│ ├── hp_tuning # Hyperparameter tuning pipeline steps
│ ├── inference # Batch inference step
│ ├── promotion # Model promotion logic
│ └── training # Model training and evaluation steps
├── utils # Helper modules
├── Makefile # Quick integration setup commands
├── requirements.txt # Python dependencies
├── run.py # CLI to run pipelines
└── README.md # This file
```

## Getting Started

1. (Optional) Create and activate a Python virtual environment:
```bash
python3 -m venv .venv
source .venv/bin/activate
```
2. Install dependencies:
```bash
make setup
```
This installs the required ZenML integrations (MLflow, Slack, Evidently, Kubeflow, Kubernetes, AWS, etc.) and any library dependencies.

3. (Optional) Set up a local Stack (if you want to try this outside Databricks):
```bash
make install-stack-local
```

4. If you have Databricks properly configured in your ZenML stack (with the Databricks token secret set up, cluster name, etc.), you can orchestrate the pipelines on Databricks by default.
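
You can sanity-check which stack a run will use from Python (assumes you have already registered and activated a Databricks stack):

```python
from zenml.client import Client

stack = Client().active_stack
# Expect an orchestrator with the "databricks" flavor on a Databricks stack
print(f"stack={stack.name}, orchestrator flavor={stack.orchestrator.flavor}")
```
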

## Running the Project

All pipeline runs happen via the CLI in `run.py`. Here are the main options:

• View available options:
```bash
python run.py --help
```

• Run everything (train, deploy, inference) with default settings:
```bash
python run.py --training --deployment --inference
```
This will:
1. Train a model and evaluate its performance
2. Deploy the model if it meets quality criteria
3. Run batch inference with SHAP explanations and data drift checks
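
`run.py` itself is not shown in this diff; a minimal sketch of how such a CLI might map these flags onto pipeline runs using click (the pipeline import names and wiring are assumptions):

```python
import click

# Hypothetical import names; only production_line_qa_batch_inference
# appears elsewhere in this PR.
from pipelines.batch_inference import production_line_qa_batch_inference
from pipelines.deployment import deployment_pipeline
from pipelines.training import training_pipeline


@click.command()
@click.option("--training", is_flag=True, help="Run the training pipeline.")
@click.option("--deployment", is_flag=True, help="Run the deployment pipeline.")
@click.option("--inference", is_flag=True, help="Run the batch inference pipeline.")
@click.option("--no-cache", is_flag=True, help="Disable step caching.")
def main(training: bool, deployment: bool, inference: bool, no_cache: bool) -> None:
    options = {"enable_cache": not no_cache}
    if training:
        training_pipeline.with_options(**options)()
    if deployment:
        deployment_pipeline.with_options(**options)()
    if inference:
        production_line_qa_batch_inference.with_options(**options)()


if __name__ == "__main__":
    main()
```
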

• Run just the training pipeline (to build or update a model):
```bash
python run.py --training
```

• Run just the deployment pipeline (to deploy the latest staged model):
```bash
python run.py --deployment
```

• Run just the batch inference pipeline (to generate predictions and explanations while checking for data drift):
```bash
python run.py --inference
```

### Additional Command-Line Flags

• Disable caching:
```bash
python run.py --no-cache --training
```

• Skip dropping NA values or skip normalization:
```bash
python run.py --no-drop-na --no-normalize --training
```

• Drop specific columns:
```bash
python run.py --training --drop-columns columnA,columnB
```

• Set minimum accuracy thresholds for the training and test sets, and fail the run when they are not met:
```bash
python run.py --min-train-accuracy 0.9 --min-test-accuracy 0.8 --fail-on-accuracy-quality-gates --training
```
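
A quality gate is conceptually just a step that compares the run's metrics against the configured thresholds and either warns or aborts. A minimal sketch of the idea (names and wiring are assumptions, not this project's actual promotion code):

```python
from zenml import step


@step
def accuracy_quality_gate(
    train_accuracy: float,
    test_accuracy: float,
    min_train_accuracy: float = 0.9,
    min_test_accuracy: float = 0.8,
    fail_on_quality_gates: bool = False,
) -> bool:
    """Check accuracy against thresholds; optionally fail the whole run."""
    passed = (
        train_accuracy >= min_train_accuracy
        and test_accuracy >= min_test_accuracy
    )
    if not passed:
        message = (
            f"Quality gate failed: train={train_accuracy:.3f} "
            f"(min {min_train_accuracy}), test={test_accuracy:.3f} "
            f"(min {min_test_accuracy})"
        )
        if fail_on_quality_gates:
            raise RuntimeError(message)  # aborts the pipeline run
        print(message)  # warn only, keep going
    return passed
```
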

When you run any of these commands, ZenML will orchestrate each pipeline on the active stack (Databricks if configured) and log the results in your model registry (MLflow). If you have Slack or other alerter components configured, you'll see success/failure notifications.
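
For example, a step can push a message through whichever alerter the active stack provides (a sketch against ZenML's generic alerter interface; the Slack channel itself is configured on the stack component):

```python
from zenml import step
from zenml.client import Client


@step
def notify_on_success() -> None:
    # Resolves to the Slack alerter if one is registered in the active stack
    alerter = Client().active_stack.alerter
    if alerter:
        alerter.post("Pipeline finished successfully.")
```
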

## Observing Your Pipelines

ZenML offers a local dashboard that you can launch with:
```bash
zenml up
```
Check the terminal logs for the local web address (usually http://127.0.0.1:8237). You'll see pipeline runs, steps, and artifacts.

If you orchestrate on Databricks, you can also see the runs in the Databricks Jobs UI. The project is flexible enough to run the same pipelines locally or in the cloud without changing any code.

## Contributing & License

Contributions and suggestions are welcome. This project is licensed under the Apache License 2.0.

For questions, feedback, or support, please reach out to the ZenML community or open an issue in this repository.

---
@@ -24,6 +24,7 @@ settings:
       - mlflow
       - sklearn
       - databricks
+    python_package_installer: "uv"
   orchestrator.databricks:
     cluster_name: adas_
     node_type_id: Standard_D8ads_v5
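
The added `python_package_installer: "uv"` line switches ZenML's Docker image builds from pip to uv, which typically installs dependencies noticeably faster. The same option can be set in code through `DockerSettings` (a minimal sketch; the pipeline name is a placeholder and uv support depends on your ZenML version):

```python
from zenml import pipeline
from zenml.config import DockerSettings

docker_settings = DockerSettings(
    required_integrations=["mlflow", "sklearn", "databricks"],
    python_package_installer="uv",  # use uv instead of pip when building the image
)


@pipeline(settings={"docker": docker_settings})
def my_databricks_pipeline() -> None:
    ...
```
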
@@ -29,6 +29,8 @@
 from zenml.integrations.evidently.steps import evidently_report_step
 from zenml.logger import get_logger

+from steps.explainability import explain_model
+
 logger = get_logger(__name__)


@@ -53,6 +55,10 @@ def production_line_qa_batch_inference():
         preprocess_pipeline=model.get_artifact("preprocess_pipeline"),
         target=target,
     )
+
+    ########## Model Explainability stage ##########
+    explain_model(df_inference)
+
     ########## DataQuality stage ##########
     report, _ = evidently_report_step(
         reference_dataset=model.get_artifact("dataset_trn"),
File renamed without changes.
@@ -0,0 +1 @@
+from .shap_explainer import explain_model
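
The new `steps/explainability` package re-exports `explain_model` from `shap_explainer.py`, which this diff does not show. A minimal sketch of what a SHAP-based explanation step along these lines might look like (the model retrieval, model name, and return format are all assumptions):

```python
import pandas as pd
import shap
from zenml import step
from zenml.client import Client


@step
def explain_model(df_inference: pd.DataFrame) -> pd.DataFrame:
    """Compute per-feature SHAP attributions for the inference batch (sketch)."""
    # Hypothetical retrieval of the trained model tracked by ZenML; the real
    # step may receive the model as an input artifact instead.
    model = (
        Client()
        .get_model_version("production_line_qa", "production")
        .get_artifact("model")
        .load()
    )

    # KernelExplainer is model-agnostic; a small background sample keeps it fast.
    background = shap.sample(df_inference, 50)
    explainer = shap.KernelExplainer(model.predict, background)
    shap_values = explainer.shap_values(df_inference)

    # One column of attributions per input feature.
    return pd.DataFrame(shap_values, columns=df_inference.columns)
```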