Skip to content

Commit 42f952a

Browse files
ZenML for the experimentation phase (#132)
* Initial implementation * Additional changes * Working as grid search * Multiprocessing works now * Finished the story * Uncommented the feature_engineering pipeline * Update native-experiment-tracking/README.md Co-authored-by: Hamza Tahir <[email protected]> * Update native-experiment-tracking/README.md Co-authored-by: Hamza Tahir <[email protected]> * Update native-experiment-tracking/README.md Co-authored-by: Hamza Tahir <[email protected]> * Update native-experiment-tracking/README.md Co-authored-by: Hamza Tahir <[email protected]> * Final cleanup --------- Co-authored-by: Hamza Tahir <[email protected]>
1 parent d3759d4 commit 42f952a

23 files changed

+1229
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.venv*
2+
.requirements*
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# :: Track experiments in ZenML natively
2+
3+
Although ZenML plugs into many [experiment trackers](https://www.zenml.io/vs/zenml-vs-experiment-trackers), a lot of
4+
the functionality of experiment trackers is already covered by ZenML's native metadata and artifact tracking.
5+
This project aims to show these capabilities.
6+
7+
## 🎯 Project Overview
8+
We're tackling a simple classification task using the breast cancer dataset. Our goal is to showcase how ZenML can effortlessly track experiments, hyperparameters, and results throughout the machine learning workflow.
9+
10+
### 🔍 What We're Doing
11+
12+
In this project, we begin by preparing the breast cancer dataset for our model through data preprocessing. For our machine learning task, we've chosen to use an SGDClassifier. Rather than relying on sklearn's GridSearchCV, we implement our own hyperparameter tuning process to showcase ZenML's robust tracking capabilities. Finally, we conduct a thorough analysis of the results, visualizing how various hyperparameters influence the model's accuracy. This approach allows us to demonstrate the power of ZenML in tracking and managing the machine learning workflow.
13+
14+
We are by no means claiming that our solution outperforms GridSearchCV, spoiler alert, this demo won't, rather, this project demonstrates how you would do hyperparameter tuning and experiment tracking with ZenML on large deep learning problems.
15+
16+
### 🛠 The Pipeline
17+
18+
Our ZenML pipeline consists of the following steps:
19+
20+
The feature_engineering pipeline:
21+
* Data Loading: Load the breast cancer dataset.
22+
* Data Splitting: Split the data into training and testing sets.
23+
* Data Pre Processing: Pre process our dataset
24+
25+
The model training pipeline:
26+
* Model Training: Train multiple SGDClassifiers with different hyperparameters.
27+
* Model Evaluation: Evaluate each model's performance.
28+
29+
By running this pipeline iteratively
30+
31+
## :running: Run locally
32+
33+
```bash
34+
# Pip install all requirements
35+
pip install -r requirements.txt
36+
37+
# Install required zenml integrations
38+
zenml integration install sklearn pandas -y
39+
40+
# Initialize ZenML
41+
zenml init
42+
43+
# Connect to your ZenML server
44+
zenml connect --url ...
45+
46+
python run.py --parallel
47+
```
48+
49+
This will run a grid search across the following parameter space:
50+
51+
```python
52+
alpha_values = [0.0001, 0.001, 0.01]
53+
penalties = ["l2", "l1", "elasticnet"]
54+
losses = ["hinge", "squared_hinge", "modified_huber"]
55+
```
56+
57+
If you choose to include the `--parallel` flag, this should all run in parallel.
58+
As ZenML smartly caches across pipelines, and because the feature pipeline has run
59+
ahead of the parallel training runs, all training pipelines should start on the
60+
`model_trainer` step.
61+
![Pipeline DAG with cached steps](./assets/pipeline_dag_caching.png)
62+
63+
After running, you now should have 27 runs of the model training with 27
64+
produced model_versions. In case you are running with [ZenML Pro](https://docs.zenml.io/getting-started/zenml-pro)
65+
you'll now be able to inspect these models in the dashboard:
66+
![Model Versions Page](./assets/model_versions.png)
67+
68+
Additionally, in case you ran with a remote [Data backend](https://docs.zenml.io/stack-components/artifact-stores),
69+
you'll be able to inspect the confusion matrix for any specific training directly in the
70+
frontend.
71+
![Confusion Matrix Visualization](./assets/cm_visualization.png)
72+
73+
In case you want to create your own visualization, check out the implementation
74+
at `native-experiment-tracking/steps/model_trainer.py:generate_cm`. Basically, just create a
75+
matplotlib plot, convert it into a `PIL.Image` and return it from your
76+
step. Don't forget to annotate your [step output accordingly](https://docs.zenml.io/how-to/build-pipelines/step-output-typing-and-annotation.
77+
78+
```python
79+
from typing import Tuple
80+
from typing_extensions import Annotated
81+
from PIL import Image
82+
from zenml import ArtifactConfig, step
83+
84+
@step
85+
def func(...) -> Tuple[
86+
Annotated[
87+
...
88+
],
89+
Annotated[
90+
Image.Image, "confusion_matrix"
91+
]
92+
]:
93+
```
94+
95+
## 📈 Explore your experiments
96+
97+
Once all pipelines ran, it is time to analyze our experiment.
98+
For this we have written an analyze.py script.
99+
```commandline
100+
python analyze.py
101+
```
102+
This will generate 2 plots for you:
103+
104+
**3D Plot**
105+
![3D Plot](./assets/3d_plot.png)
106+
107+
**2D Plot**
108+
![2D Plot](./assets/2d_plot.png)
109+
110+
Feel free to use this file as a starting point to write your very own
111+
analysis.
112+
113+
## The moral of the story
114+
115+
So what's the point? We at ZenML believe that any good experiment should be set up in a
116+
repeatable, scalable way while storing all the relevant metadata in order to analyze the experiment
117+
after the fact. This project shows how you could do this with ZenML.
118+
119+
Once you have accomplished this on a toy dataset with a tiny SGDClassifier, you can start
120+
scaling up in all dimensions: data, parameters, model, etc... And all of this while staying infrastructure
121+
agnostic. So when your experiment outgrows your local machine, you can simply move
122+
to the stack of your choice ...
123+
124+
## 🤝 Contributing
125+
126+
Contributions to improve the pipeline are welcome! Please feel free to submit a Pull Request.
127+
128+
## 📄 License
129+
130+
This project is licensed under the Apache License 2.0. See the LICENSE file for details.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# Apache Software License 2.0
2+
#
3+
# Copyright (c) ZenML GmbH 2024. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
import pandas as pd
20+
import seaborn as sns
21+
from zenml.client import Client
22+
23+
24+
def main():
25+
client = Client()
26+
27+
model_versions = client.list_model_versions(
28+
model_name_or_id="breast_cancer_classifier", size=27, hydrate=True
29+
)
30+
31+
alpha_values = []
32+
losses = []
33+
penalties = []
34+
test_accuracies = []
35+
train_accuracies = []
36+
37+
for model_version in model_versions:
38+
mv_metadata = model_version.run_metadata
39+
40+
alpha_values.append(mv_metadata.get("alpha_value", None).value)
41+
losses.append(mv_metadata.get("loss", None).value)
42+
penalties.append(mv_metadata.get("penalty", None).value)
43+
test_accuracies.append(mv_metadata.get("test_accuracy", None).value)
44+
train_accuracies.append(mv_metadata.get("train_accuracy", None).value)
45+
46+
generate_3d_plot(alpha_values, losses, penalties, test_accuracies)
47+
generate_2d_plots(alpha_values, losses, penalties, test_accuracies)
48+
49+
50+
def generate_2d_plots(alpha_values, losses, penalties, test_accuracies):
51+
# Convert the data into a DataFrame
52+
df = pd.DataFrame(
53+
{
54+
"Alpha": alpha_values,
55+
"Loss": losses,
56+
"Penalty": penalties,
57+
"Accuracy": test_accuracies,
58+
}
59+
)
60+
61+
# Get unique values
62+
unique_penalties = df["Penalty"].unique()
63+
64+
# Create a figure with subplots for each penalty
65+
fig, axes = plt.subplots(
66+
1, len(unique_penalties), figsize=(20, 6), sharey=True
67+
)
68+
fig.suptitle("Accuracy Heatmap for Different Penalties", fontsize=16)
69+
70+
for i, penalty in enumerate(unique_penalties):
71+
# Filter data for the current penalty
72+
df_penalty = df[df["Penalty"] == penalty]
73+
74+
# Create a pivot table
75+
pivot = df_penalty.pivot(
76+
index="Loss", columns="Alpha", values="Accuracy"
77+
)
78+
79+
# Create heatmap
80+
sns.heatmap(
81+
pivot,
82+
ax=axes[i],
83+
cmap="viridis",
84+
annot=True,
85+
fmt=".3f",
86+
cbar=False,
87+
)
88+
89+
axes[i].set_title(f"Penalty: {penalty}")
90+
axes[i].set_xlabel("Alpha")
91+
92+
if i == 0:
93+
axes[i].set_ylabel("Loss")
94+
95+
# Add a colorbar to the right of the subplots
96+
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
97+
fig.colorbar(axes[0].collections[0], cax=cbar_ax, label="Accuracy")
98+
99+
plt.tight_layout(rect=[0, 0, 0.9, 1])
100+
plt.show()
101+
102+
103+
def generate_3d_plot(alpha_values, losses, penalties, test_accuracies):
104+
# Convert losses and penalties to numerical indices
105+
unique_losses = list(set(losses))
106+
unique_penalties = list(set(penalties))
107+
108+
loss_indices = [unique_losses.index(loss) for loss in losses]
109+
penalty_indices = [
110+
unique_penalties.index(penalty) for penalty in penalties
111+
]
112+
113+
# Create a figure and a 3D axis
114+
fig = plt.figure(figsize=(12, 8))
115+
ax = fig.add_subplot(111, projection="3d")
116+
117+
# Create a scatter plot
118+
scatter = ax.scatter(
119+
alpha_values,
120+
loss_indices,
121+
penalty_indices,
122+
c=test_accuracies,
123+
cmap="viridis",
124+
)
125+
# Find the point with the highest accuracy
126+
max_accuracy_index = np.argmax(test_accuracies)
127+
max_accuracy = test_accuracies[max_accuracy_index]
128+
max_alpha = alpha_values[max_accuracy_index]
129+
max_loss = losses[max_accuracy_index]
130+
max_penalty = penalties[max_accuracy_index]
131+
132+
# Highlight the point with the highest accuracy
133+
ax.scatter(
134+
[max_alpha],
135+
[loss_indices[max_accuracy_index]],
136+
[penalty_indices[max_accuracy_index]],
137+
c="red",
138+
s=100,
139+
edgecolors="black",
140+
linewidths=2,
141+
zorder=10,
142+
)
143+
144+
# Set labels for each axis
145+
ax.set_xlabel("Alpha")
146+
ax.set_ylabel("Loss")
147+
ax.set_zlabel("Penalty")
148+
149+
# Set custom ticks for loss and penalty axes
150+
ax.set_yticks(range(len(unique_losses)))
151+
ax.set_yticklabels(unique_losses)
152+
ax.set_zticks(range(len(unique_penalties)))
153+
ax.set_zticklabels(unique_penalties)
154+
155+
# Add a color bar
156+
cbar = plt.colorbar(scatter)
157+
cbar.set_label("Accuracy")
158+
159+
# Set a title
160+
plt.title("Accuracy vs. Alpha, Loss, and Penalty")
161+
162+
# Adjust the viewing angle
163+
ax.view_init(elev=20, azim=45)
164+
165+
# Add legend with highest accuracy point description
166+
legend_text = f"Highest Accuracy:\nAccuracy: {max_accuracy:.4f}\nAlpha: {max_alpha}\nLoss: {max_loss}\nPenalty: {max_penalty}"
167+
ax.text2D(
168+
0.05,
169+
0.95,
170+
legend_text,
171+
transform=ax.transAxes,
172+
fontsize=10,
173+
verticalalignment="top",
174+
bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
175+
)
176+
177+
# Show the plot
178+
plt.tight_layout()
179+
plt.show()
180+
return
181+
182+
183+
if __name__ == "__main__":
184+
main()
54.1 KB
Loading
106 KB
Loading
90.8 KB
Loading
126 KB
Loading
33.9 KB
Loading
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# environment configuration
2+
settings:
3+
docker:
4+
required_integrations:
5+
- sklearn
6+
- pandas
7+
requirements:
8+
- pyarrow
9+
10+
# pipeline configuration
11+
test_size: 0.35
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# environment configuration
2+
settings:
3+
docker:
4+
required_integrations:
5+
- sklearn
6+
- pandas
7+
requirements:
8+
- pyarrow
9+
- matplotlib
10+
- pillow
11+
- numpy

0 commit comments

Comments
 (0)