Skip to content

Commit 41518e0

Browse files
committed
add demo (must be deleted)
1 parent bd104d7 commit 41518e0

File tree

9 files changed

+582
-0
lines changed

9 files changed

+582
-0
lines changed

demo/README.md

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# ZenML Implementation Guide
2+
3+
## Overview
4+
This guide outlines the step-by-step process for setting up and running the demonstrated ZenML pipeline with Neptune experiment tracking integration. The implementation follows a systematic approach to ensure reproducible machine learning workflows.
5+
6+
## Prerequisites
7+
- Python 3.9 or higher
8+
- Access to Neptune.ai account
9+
- ZenML cloud account
10+
11+
## Installation and Setup Process
12+
13+
### 1. Environment Setup
14+
First, create and activate a dedicated virtual environment:
15+
16+
```bash
17+
# Create virtual environment
18+
python -m venv .venv
19+
20+
# Activate virtual environment
21+
# For Unix/MacOS
22+
source .venv/bin/activate
23+
```
24+
25+
### 2. Dependencies Installation
26+
Install required packages from the requirements file:
27+
28+
```bash
29+
pip install -r requirements.txt
30+
```
31+
32+
### 3. ZenML Configuration
33+
Initialize and configure ZenML with the following steps:
34+
35+
```bash
36+
# Initialize ZenML in your project directory
37+
zenml init
38+
zenml integration install pytorch_lightning neptune
39+
40+
# Connect to ZenML cloud tenant (you can find this command in the overview page of your ZenML cloud tenant)
41+
zenml login 8a462fb6-b...
42+
43+
# Register Neptune experiment tracker
44+
zenml experiment-tracker register neptune_experiment_tracker \
45+
--flavor=neptune \
46+
--project="" \
47+
--api_token=""
48+
49+
# Register and configure stack
50+
zenml stack register neptune_stack \
51+
-o default \
52+
-a default \
53+
-e neptune_experiment_tracker
54+
55+
# Set as active stack
56+
zenml stack set neptune_stack
57+
```
58+
59+
### 4. Execute Pipeline
60+
Run the implementation:
61+
62+
```bash
63+
python run.py
64+
```
65+
66+
## Troubleshooting
67+
- Ensure all environment variables are properly set
68+
- Verify Neptune.ai credentials are correctly configured
69+
- Check ZenML stack status using `zenml stack list`
70+

demo/configs/config.yaml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
model:
2+
name: cifar10_resnet18
3+
description: "Fine-tune with ResNet18 on CIFAR10 using PyTorch Lightning and Neptune in GCP"
4+
tags:
5+
- pytorch_lightning
6+
- demo
7+
- neptune
8+
- cifar10
9+
- gcp
10+
11+
settings:
12+
docker:
13+
#parent_image: pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime
14+
python_package_installer: uv
15+
required_integrations:
16+
- pytorch
17+
- neptune
18+
- gcp
19+
- pytorch_lightning
20+
requirements:
21+
- torchvision
22+
- lightning
23+
#- zenml==0.73.0
24+
25+
parameters:
26+
# Data parameters
27+
batch_size: 256
28+
val_split: 0.2
29+
dataset_fraction: 0.05 # Use only 10% of the data for faster demo
30+
31+
# Training parameters
32+
epochs: 2
33+
learning_rate: 0.04

demo/pipelines/cifar10_pipeline.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Dict
2+
3+
from steps.data_loader import load_cifar10_data
4+
from steps.evaluator import evaluate_model
5+
from steps.trainer import train_model
6+
7+
from zenml import pipeline
8+
from zenml.config import DockerSettings
9+
from zenml.config.resource_settings import ResourceSettings
10+
from zenml.integrations.constants import PYTORCH
11+
from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import (
12+
VertexOrchestratorSettings,
13+
)
14+
15+
vertex_settings = VertexOrchestratorSettings(
16+
pod_settings={
17+
"node_selectors": {
18+
"cloud.google.com/gke-accelerator": "NVIDIA_TESLA_V100",
19+
},
20+
}
21+
)
22+
#resource_settings = ResourceSettings(gpu_count=1)
23+
resource_settings = ResourceSettings(cpu_count=16, memory="32GB")
24+
@pipeline(
25+
settings={
26+
#"orchestrator": vertex_settings,
27+
"resources": resource_settings,
28+
},
29+
enable_cache=True
30+
)
31+
def cifar10_pipeline(
32+
batch_size: int = 256,
33+
val_split: float = 0.2,
34+
dataset_fraction: float = 0.05, # Control dataset size
35+
epochs: int = 5,
36+
learning_rate: float = 0.05
37+
) -> Dict[str, float]:
38+
"""Training pipeline for CIFAR10 image classification.
39+
40+
Args:
41+
batch_size: The batch size for training and evaluation.
42+
val_split: The fraction of the dataset to use for validation.
43+
dataset_fraction: The fraction of total dataset to use (for faster demo).
44+
epochs: The number of epochs to train the model.
45+
learning_rate: The learning rate for the optimizer.
46+
47+
Returns:
48+
A dictionary containing the test loss and accuracy.
49+
"""
50+
# Load and prepare data
51+
train_dataloader, val_dataloader, test_dataloader = load_cifar10_data(
52+
batch_size=batch_size,
53+
val_split=val_split,
54+
dataset_fraction=dataset_fraction
55+
)
56+
57+
# Train model
58+
model = train_model(
59+
train_dataloader=train_dataloader,
60+
val_dataloader=val_dataloader,
61+
epochs=epochs,
62+
lr=learning_rate,
63+
)
64+
65+
# Evaluate model
66+
metrics = evaluate_model(
67+
model=model,
68+
test_dataloader=test_dataloader
69+
)
70+
71+
return metrics

demo/requirements.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
torch
2+
torchvision
3+
torchmetrics
4+
zenml
5+
click
6+
pyyaml
7+
torchvision
8+
lightning

demo/run.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os
2+
from typing import Optional
3+
4+
import click
5+
import yaml
6+
from pipelines.cifar10_pipeline import cifar10_pipeline
7+
8+
from zenml.client import Client
9+
from zenml.config.schedule import Schedule
10+
from zenml.integrations.neptune.experiment_trackers import (
11+
NeptuneExperimentTracker,
12+
)
13+
14+
15+
@click.command(
16+
help="""
17+
ZenML CIFAR10 Training Demo CLI.
18+
19+
Run the ZenML CIFAR10 image classification training pipeline.
20+
21+
Examples:
22+
23+
\b
24+
# Run the pipeline with default config
25+
python run.py
26+
27+
\b
28+
# Run the pipeline with custom config
29+
python run.py --config custom_config.yaml
30+
31+
\b
32+
# Run without caching
33+
python run.py --no-cache
34+
"""
35+
)
36+
@click.option(
37+
"--config-path",
38+
type=str,
39+
default="configs/config.yaml",
40+
help="Path to the YAML config file.",
41+
)
42+
@click.option(
43+
"--no-cache",
44+
is_flag=True,
45+
default=False,
46+
help="Disable caching for the pipeline run.",
47+
)
48+
def main(config_path: Optional[str] = None, no_cache: bool = False) -> None:
49+
"""Main entry point for the pipeline execution.
50+
51+
Args:
52+
config: Path to the YAML config file.
53+
no_cache: If True, disable caching.
54+
"""
55+
if not config_path:
56+
raise RuntimeError("Config file is required to run the pipeline.")
57+
58+
# Ensure config path is absolute
59+
if not os.path.isabs(config_path):
60+
config_path = os.path.join(
61+
os.path.dirname(os.path.realpath(__file__)),
62+
config_path
63+
)
64+
65+
# Load configuration
66+
with open(config_path, "r") as f:
67+
config_dict = yaml.safe_load(f)
68+
69+
# Ensure neptune experiment tracker is active
70+
stack = Client().active_stack
71+
if not isinstance(stack.experiment_tracker, NeptuneExperimentTracker):
72+
raise RuntimeError(
73+
"This pipeline requires an Neptune experiment tracker in the active stack. "
74+
"Please run: zenml experiment-tracker register neptune"
75+
)
76+
77+
# Run the pipeline
78+
pipeline_args = {"enable_cache": not no_cache}
79+
pipeline_args["config_path"] = config_path
80+
metrics = cifar10_pipeline.with_options(**pipeline_args,)(
81+
batch_size=config_dict["parameters"]["batch_size"],
82+
val_split=config_dict["parameters"]["val_split"],
83+
dataset_fraction=config_dict["parameters"]["dataset_fraction"],
84+
epochs=config_dict["parameters"]["epochs"],
85+
learning_rate=config_dict["parameters"]["learning_rate"],
86+
)
87+
88+
click.echo("Training completed!")
89+
click.echo(f"Test metrics: {metrics}")
90+
91+
92+
if __name__ == "__main__":
93+
main()

demo/steps/data_loader.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import os
2+
import numpy as np
3+
import torch
4+
import torchvision
5+
from typing import Tuple, Annotated, List
6+
from torch.utils.data import DataLoader, random_split, Subset
7+
from zenml import step
8+
9+
# Constants
10+
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
11+
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
12+
NUM_WORKERS = int(os.cpu_count() / 2) if os.cpu_count() else 2
13+
14+
# Data normalization
15+
cifar10_normalization = torchvision.transforms.Normalize(
16+
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
17+
std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
18+
)
19+
20+
train_transforms = torchvision.transforms.Compose([
21+
torchvision.transforms.RandomCrop(32, padding=4),
22+
torchvision.transforms.RandomHorizontalFlip(),
23+
torchvision.transforms.ToTensor(),
24+
cifar10_normalization,
25+
])
26+
27+
test_transforms = torchvision.transforms.Compose([
28+
torchvision.transforms.ToTensor(),
29+
cifar10_normalization,
30+
])
31+
32+
def get_subset_indices(total_size: int, fraction: float) -> List[int]:
33+
"""Get random indices for subset of data.
34+
35+
Args:
36+
total_size: Total size of the dataset
37+
fraction: Fraction of data to use
38+
39+
Returns:
40+
List of indices for the subset
41+
"""
42+
num_samples = int(total_size * fraction)
43+
indices = np.random.permutation(total_size)[:num_samples].tolist()
44+
return indices
45+
46+
@step
47+
def load_cifar10_data(
48+
batch_size: int = BATCH_SIZE,
49+
val_split: float = 0.2,
50+
dataset_fraction: float = 0.05 # Use only 20% of the data by default
51+
) -> Tuple[
52+
Annotated[DataLoader, "train_dataloader"],
53+
Annotated[DataLoader, "val_dataloader"],
54+
Annotated[DataLoader, "test_dataloader"]
55+
]:
56+
"""Load and prepare CIFAR10 datasets.
57+
58+
Args:
59+
batch_size: Batch size for the dataloaders
60+
val_split: Fraction of training data to use for validation
61+
dataset_fraction: Fraction of total dataset to use (for faster demo)
62+
"""
63+
# Set random seed for reproducibility
64+
np.random.seed(42)
65+
66+
# Load full datasets
67+
dataset_train_full = torchvision.datasets.CIFAR10(PATH_DATASETS, train=True, download=True, transform=train_transforms)
68+
dataset_test_full = torchvision.datasets.CIFAR10(PATH_DATASETS, train=False, download=True, transform=test_transforms)
69+
70+
# Get subset indices
71+
train_indices = get_subset_indices(len(dataset_train_full), dataset_fraction)
72+
test_indices = get_subset_indices(len(dataset_test_full), dataset_fraction)
73+
# Create subsets
74+
dataset_train = Subset(dataset_train_full, train_indices)
75+
dataset_test = Subset(dataset_test_full, test_indices)
76+
77+
# Split training into train and validation
78+
train_length = int(len(dataset_train) * (1 - val_split))
79+
val_length = len(dataset_train) - train_length
80+
dataset_train, dataset_val = random_split(
81+
dataset_train,
82+
[train_length, val_length],
83+
generator=torch.Generator().manual_seed(42)
84+
)
85+
86+
print(f"Dataset sizes:")
87+
print(f"Original training set: {len(dataset_train_full)} samples")
88+
print(f"Original test set: {len(dataset_test_full)} samples")
89+
print(f"After {dataset_fraction*100:.1f}% subset:")
90+
print(f" Training: {len(dataset_train)} samples")
91+
print(f" Validation: {len(dataset_val)} samples")
92+
print(f" Test: {len(dataset_test)} samples")
93+
94+
# Create dataloaders
95+
train_dataloader = DataLoader(
96+
dataset_train,
97+
batch_size=batch_size,
98+
shuffle=True,
99+
num_workers=NUM_WORKERS
100+
)
101+
val_dataloader = DataLoader(
102+
dataset_val,
103+
batch_size=batch_size,
104+
shuffle=False,
105+
num_workers=NUM_WORKERS
106+
)
107+
test_dataloader = DataLoader(
108+
dataset_test,
109+
batch_size=batch_size,
110+
shuffle=False,
111+
num_workers=NUM_WORKERS
112+
)
113+
114+
return train_dataloader, val_dataloader, test_dataloader

0 commit comments

Comments
 (0)