Skip to content

Commit ca40be8

Browse files
committed
Working as grid search
1 parent 8236fec commit ca40be8

File tree

7 files changed

+178
-239
lines changed

7 files changed

+178
-239
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
from matplotlib import pyplot as plt
3+
from zenml.client import Client
4+
5+
6+
def main():
7+
client = Client()
8+
9+
model_versions = client.list_model_versions(model_name_or_id="breast_cancer_classifier", size=30, hydrate=True)
10+
11+
alpha_values = []
12+
losses = []
13+
penalties = []
14+
test_accuracies = []
15+
train_accuracies = []
16+
17+
for model_version in model_versions:
18+
mv_metadata = model_version.run_metadata
19+
20+
alpha_values.append(mv_metadata.get("alpha_value", None).value)
21+
losses.append(mv_metadata.get("loss", None).value)
22+
penalties.append(mv_metadata.get("penalty", None).value)
23+
test_accuracies.append(mv_metadata.get("test_accuracy", None).value)
24+
train_accuracies.append(mv_metadata.get("train_accuracy", None).value)
25+
26+
generate_plot(alpha_values, losses, penalties, test_accuracies)
27+
28+
29+
def generate_plot(alpha_values, losses, penalties, test_accuracies):
30+
# Convert losses and penalties to numerical indices
31+
unique_losses = list(set(losses))
32+
unique_penalties = list(set(penalties))
33+
34+
loss_indices = [unique_losses.index(loss) for loss in losses]
35+
penalty_indices = [unique_penalties.index(penalty) for penalty in penalties]
36+
37+
# Create a figure and a 3D axis
38+
fig = plt.figure(figsize=(12, 8))
39+
ax = fig.add_subplot(111, projection='3d')
40+
41+
# Create a scatter plot
42+
scatter = ax.scatter(alpha_values, loss_indices, penalty_indices, c=test_accuracies, cmap='viridis')
43+
44+
# Set labels for each axis
45+
ax.set_xlabel('Alpha')
46+
ax.set_ylabel('Loss')
47+
ax.set_zlabel('Penalty')
48+
49+
# Set custom ticks for loss and penalty axes
50+
ax.set_yticks(range(len(unique_losses)))
51+
ax.set_yticklabels(unique_losses)
52+
ax.set_zticks(range(len(unique_penalties)))
53+
ax.set_zticklabels(unique_penalties)
54+
55+
# Add a color bar
56+
cbar = plt.colorbar(scatter)
57+
cbar.set_label('Accuracy')
58+
59+
# Set a title
60+
plt.title('Accuracy vs. Alpha, Loss, and Penalty')
61+
62+
# Adjust the viewing angle
63+
ax.view_init(elev=20, azim=45)
64+
65+
# Show the plot
66+
plt.tight_layout()
67+
plt.show()
68+
return
69+
70+
71+
72+
if __name__ == "__main__":
73+
main()

native-experiment-tracking/configs/training.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,3 @@ settings:
99
- matplotlib
1010
- pillow
1111
- numpy
12-
13-
# configuration of the Model Control Plane
14-
model:
15-
name: breast_cancer_classifier
16-
license: Apache 2.0
17-
description: A breast cancer classifier
18-
tags: ["breast_cancer", "classifier"]

native-experiment-tracking/pipelines/training.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Optional
1919
from uuid import UUID
2020

21-
from steps import model_evaluator, model_promoter, model_trainer, model_grid_searcher
21+
from steps import model_evaluator, model_promoter, model_trainer
2222

2323
from pipelines import (
2424
feature_engineering,
@@ -33,8 +33,9 @@
3333

3434
@pipeline
3535
def training(
36-
train_dataset_id: Optional[UUID] = None,
37-
test_dataset_id: Optional[UUID] = None,
36+
alpha_value: float,
37+
penalty: str,
38+
loss: str,
3839
target: Optional[str] = "target",
3940
):
4041
"""
@@ -47,27 +48,19 @@ def training(
4748
model version.
4849
4950
Args:
50-
train_dataset_id: ID of the train dataset produced by feature engineering.
51-
test_dataset_id: ID of the test dataset produced by feature engineering.
5251
target: Name of target column in dataset.
52+
alpha_value: Alpha value to use for the train step,
53+
penalty: Penalty to use for sgd,
54+
loss: Loss function to be used for sgd,
5355
"""
5456
# Link all the steps together by calling them and passing the output
5557
# of one step as the input of the next step.
5658

5759
# Execute Feature Engineering Pipeline
58-
if train_dataset_id is None or test_dataset_id is None:
59-
dataset_trn, dataset_tst = feature_engineering()
60-
else:
61-
client = Client()
62-
dataset_trn = client.get_artifact_version(
63-
name_id_or_prefix=train_dataset_id
64-
)
65-
dataset_tst = client.get_artifact_version(
66-
name_id_or_prefix=test_dataset_id
67-
)
60+
dataset_trn, dataset_tst = feature_engineering()
6861

69-
model, _, _ = model_grid_searcher(
70-
dataset_trn=dataset_trn, target=target
62+
model, _ = model_trainer(
63+
dataset_trn=dataset_trn, target=target, alpha_value=alpha_value, penalty=penalty, loss=loss
7164
)
7265

7366
acc, _ = model_evaluator(
@@ -76,5 +69,3 @@ def training(
7669
dataset_tst=dataset_tst,
7770
target=target,
7871
)
79-
80-
model_promoter(accuracy=acc)

native-experiment-tracking/run.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,24 @@
1717
import os
1818

1919
import click
20-
20+
from sklearn.utils._param_validation import InvalidParameterError
21+
from zenml import Model
22+
from zenml.client import Client
2123
from zenml.logger import get_logger
2224

2325
from pipelines import training
2426

2527
logger = get_logger(__name__)
2628

29+
2730
@click.option(
2831
"--no-cache",
2932
is_flag=True,
3033
default=False,
3134
help="Disable caching for the pipeline run.",
3235
)
3336
def main(
34-
no_cache: bool = False,
37+
no_cache: bool = False,
3538
):
3639
"""Main entry point for the pipeline execution.
3740
@@ -45,20 +48,34 @@ def main(
4548
Args:
4649
no_cache: If `True` cache will be disabled.
4750
"""
48-
config_folder = os.path.join(
51+
client = Client()
52+
config_path = os.path.join(
4953
os.path.dirname(os.path.realpath(__file__)),
5054
"configs",
55+
"training.yaml"
5156
)
57+
enable_cache = not no_cache
58+
59+
alpha_values = [0.0001, 0.001, 0.01]
60+
penalties = ["l2", "l1", "elasticnet"]
61+
losses = ["hinge", "squared_hinge", "modified_huber"]
62+
for penalty in penalties:
63+
for loss in losses:
64+
for alpha_value in alpha_values:
65+
logger.info(f"Training with alpha: {alpha_value}, penalty: {penalty}, loss: {loss}")
66+
67+
model = Model(
68+
name="breast_cancer_classifier",
69+
tags=[f"alpha: {alpha_value}", f"penalty: {penalty}", f"loss: {loss}"]
70+
)
71+
try:
72+
training.with_options(config_path=config_path, enable_cache=enable_cache, model=model)(
73+
alpha_value=alpha_value, penalty=penalty, loss=loss)
74+
except RuntimeError:
75+
pass
76+
else:
77+
logger.info("Training pipeline finished successfully!\n\n")
5278

53-
pipeline_args = {}
54-
if no_cache:
55-
pipeline_args["enable_cache"] = False
56-
pipeline_args["config_path"] = os.path.join(
57-
config_folder, "training.yaml"
58-
)
59-
training.with_options(**pipeline_args)()
60-
training.with_options(**pipeline_args)()
61-
logger.info("Training pipeline finished successfully!\n\n")
6279

6380
if __name__ == "__main__":
6481
main()

native-experiment-tracking/steps/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,3 @@
3333
from .model_trainer import (
3434
model_trainer,
3535
)
36-
from .model_grid_search import (
37-
model_grid_searcher
38-
)

native-experiment-tracking/steps/model_grid_search.py

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)