185 changes: 185 additions & 0 deletions benchmarks/imagenet1k/nnclr-resnet50.py
@@ -0,0 +1,185 @@
import lightning as pl
import torch
import torchmetrics
from lightning.pytorch.loggers import WandbLogger
from torch import nn

import stable_pretraining as spt
from stable_pretraining.data import transforms
from stable_pretraining.forward import nnclr_forward
from stable_pretraining.callbacks.queue import OnlineQueue
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))
from utils import get_data_dir

nnclr_transform = transforms.MultiViewTransform(
[
transforms.Compose(
transforms.RGB(),
transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0)),
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.PILGaussianBlur(p=1.0),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToImage(**spt.data.static.ImageNet),
),
transforms.Compose(
transforms.RGB(),
transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0)),
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.PILGaussianBlur(p=0.1),
transforms.RandomSolarize(threshold=0.5, p=0.2),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToImage(**spt.data.static.ImageNet),
),
]
)

val_transform = transforms.Compose(
transforms.RGB(),
transforms.Resize((256, 256)),
transforms.CenterCrop((224, 224)),
transforms.ToImage(**spt.data.static.ImageNet),
)

train_dataset = spt.data.HFDataset(
"imagenet-1k",
split="train",
cache_dir=str(get_data_dir()),
transform=nnclr_transform,
)
val_dataset = spt.data.HFDataset(
"imagenet-1k",
split="validation",
cache_dir=str(get_data_dir()),
transform=val_transform,
)

total_batch_size, world_size, num_epochs = 4096, 4, 400
local_batch_size = total_batch_size // world_size
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
sampler=spt.data.sampler.RepeatedRandomSampler(train_dataset, n_views=2),
batch_size=local_batch_size,
num_workers=64,
drop_last=True,
persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
dataset=val_dataset,
batch_size=local_batch_size,
num_workers=32,
persistent_workers=True,
)

data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)

backbone = spt.backbone.from_torchvision(
"resnet50",
low_resolution=False,
)
backbone.fc = torch.nn.Identity()

projector = nn.Sequential(
nn.Linear(2048, 2048),
nn.BatchNorm1d(2048),
nn.ReLU(inplace=True),
nn.Linear(2048, 2048),
nn.BatchNorm1d(2048),
nn.ReLU(inplace=True),
nn.Linear(2048, 256),
)

predictor = nn.Sequential(
nn.Linear(256, 4096),
nn.BatchNorm1d(4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 256),
)


module = spt.Module(
backbone=backbone,
projector=projector,
predictor=predictor,
forward=nnclr_forward,
nnclr_loss=spt.losses.NTXEntLoss(temperature=0.1),
optim={
"optimizer": {
"type": "LARS",
"lr": 0.3 * (total_batch_size / 256),
"weight_decay": 1e-6,
"clip_lr": True,
"eta": 0.001,
"exclude_bias_n_norm": True,
},
"scheduler": {
"type": "LinearWarmupCosineAnnealing",
"peak_step": 10 / num_epochs,
"total_steps": num_epochs * (len(train_dataloader) // world_size),
},
"interval": "epoch",
},
hparams={
"support_set_size": 98304, # was 16384
"projection_dim": 256, # keep
},
)

linear_probe = spt.callbacks.OnlineProbe(
name="linear_probe",
input="embedding",
target="label",
probe=torch.nn.Linear(2048, 1000),
loss_fn=torch.nn.CrossEntropyLoss(),
metrics={
"top1": torchmetrics.classification.MulticlassAccuracy(1000),
"top5": torchmetrics.classification.MulticlassAccuracy(1000, top_k=5),
},
)

knn_probe = spt.callbacks.OnlineKNN(
name="knn_probe",
input="embedding",
target="label",
queue_length=20000,
metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(1000)},
input_dim=2048,
k=20,
)

support_queue = OnlineQueue(
key="nnclr_support_set",
queue_length=module.hparams.support_set_size,
dim=module.hparams.projection_dim,
)
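
# Rough sketch of the nearest-neighbor swap this queue enables (an assumption about
# what nnclr_forward does with the "nnclr_support_set" key, not the library's code):
#
#   z = torch.nn.functional.normalize(projections, dim=1)
#   support = torch.nn.functional.normalize(queue_features, dim=1)
#   nn_z = support[(z @ support.T).argmax(dim=1)]  # closest support embedding per sample
#
# The swapped embeddings then act as the positives in the NT-Xent loss.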

wandb_logger = WandbLogger(
entity="samibg",
project="imagenet-1k-nnclr",
name=f"nnclr-resnet50-{world_size}gpus",
log_model=False,
)

# --- Trainer ---
trainer = pl.Trainer(
max_epochs=num_epochs,
num_sanity_val_steps=0,
callbacks=[linear_probe, knn_probe, support_queue],
precision="16-mixed",
logger=wandb_logger,
enable_checkpointing=True,
accelerator="gpu",
devices=world_size,
sync_batchnorm=world_size > 1,
)

manager = spt.Manager(trainer=trainer, module=module, data=data)
manager()
169 changes: 169 additions & 0 deletions benchmarks/imagenet1k/simclr-resnet50.py
@@ -0,0 +1,169 @@
import lightning as pl
import torch
import torchmetrics
from lightning.pytorch.loggers import WandbLogger
from torch import nn

import stable_pretraining as spt
from stable_pretraining.data import transforms
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))
from utils import get_data_dir

# SimCLR augmentations, as described in the paper
simclr_transform = transforms.MultiViewTransform(
[
transforms.Compose(
transforms.RGB(),
transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), ratio=(3/4, 4/3)),
transforms.ColorJitter(
brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.PILGaussianBlur(sigma=(0.1, 2.0), p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToImage(**spt.data.static.ImageNet),
),
transforms.Compose(
transforms.RGB(),
transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), ratio=(3/4, 4/3)),
transforms.ColorJitter(
brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.PILGaussianBlur(sigma=(0.1, 2.0), p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToImage(**spt.data.static.ImageNet),
),
]
)

val_transform = transforms.Compose(
transforms.RGB(),
transforms.Resize((256, 256)),
transforms.CenterCrop((224, 224)),
transforms.ToImage(**spt.data.static.ImageNet),
)

# Using a standard ImageNet-1k dataset
train_dataset = spt.data.HFDataset(
"imagenet-1k",
split="train",
cache_dir=str(get_data_dir()),
transform=simclr_transform,
)
val_dataset = spt.data.HFDataset(
"imagenet-1k",
split="validation",
cache_dir=str(get_data_dir()),
transform=val_transform,
)

# Batch size from the paper (adjust if necessary)
total_batch_size, world_size, num_epochs = 4096, 8, 800
local_batch_size = total_batch_size // world_size

train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=local_batch_size,
num_workers=64,
drop_last=True,
persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
dataset=val_dataset,
batch_size=local_batch_size,
num_workers=32,
persistent_workers=True,
)

data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)

# Using ResNet-50 as in the paper
backbone = spt.backbone.from_torchvision(
"resnet50",
low_resolution=False,
)
backbone.fc = torch.nn.Identity()

# Projector network as described in SimCLR paper
projector = nn.Sequential(
nn.Linear(2048, 2048),
nn.ReLU(inplace=True),
nn.Linear(2048, 128),
)

module = spt.Module(
backbone=backbone,
projector=projector,
forward=spt.forward.simclr_forward,
# Temperature can be tuned, 0.1 is a common value
simclr_loss=spt.losses.NTXEntLoss(temperature=0.1),
optim={
"optimizer": {
"type": "LARS",
"lr": 0.3
* (total_batch_size / 256), # 256 is base batch size they use in SimCLR
"weight_decay": 1e-6,
"clip_lr": True,
"eta": 1e-3,
"exclude_bias_n_norm": True,
},
"scheduler": {
"type": "LinearWarmupCosineAnnealing",
"peak_step": 10 / num_epochs, # 10 epochs warmup
"total_steps": num_epochs * (len(train_dataloader) // world_size),
},
"interval": "epoch",
},
)
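
# Illustrative sketch of the NT-Xent objective that NTXEntLoss is assumed to compute
# (not the library's actual implementation): for two batches of projections z1, z2,
#
#   z = torch.nn.functional.normalize(torch.cat([z1, z2]), dim=1)
#   sim = z @ z.T / temperature                  # temperature-scaled cosine similarities
#   sim.fill_diagonal_(float("-inf"))            # mask self-similarity
#   n = z1.shape[0]
#   targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)]).to(sim.device)
#   loss = torch.nn.functional.cross_entropy(sim, targets)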

# Online linear and kNN probes over the 1000 ImageNet classes
linear_probe = spt.callbacks.OnlineProbe(
name="linear_probe",
input="embedding",
target="label",
probe=torch.nn.Linear(2048, 1000),
loss_fn=torch.nn.CrossEntropyLoss(),
metrics={
"top1": torchmetrics.classification.MulticlassAccuracy(1000),
"top5": torchmetrics.classification.MulticlassAccuracy(1000, top_k=5),
},
)

knn_probe = spt.callbacks.OnlineKNN(
name="knn_probe",
input="embedding",
target="label",
queue_length=20000,
metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(1000)},
input_dim=2048,
k=20,
)

wandb_logger = WandbLogger(
entity="samibg",
project="imagenet1k-simclr",
name="simclr-resnet50-4gpu",
log_model=False,
)

trainer = pl.Trainer(
max_epochs=num_epochs,
num_sanity_val_steps=0,
callbacks=[
linear_probe,
knn_probe,
],
precision="16-mixed",
logger=wandb_logger,
enable_checkpointing=True,
accelerator="gpu",
sync_batchnorm=world_size > 1,
)

manager = spt.Manager(trainer=trainer, module=module, data=data)
manager()
3 changes: 2 additions & 1 deletion stable_pretraining/__init__.py
@@ -25,7 +25,7 @@
except ImportError:
WANDB_AVAILABLE = False

from . import backbone, callbacks, data, losses, module, optim, static, utils
from . import backbone, callbacks, data, forward, losses, module, optim, static, utils
from .__about__ import (
__author__,
__license__,
@@ -81,6 +81,7 @@
"module",
"static",
"optim",
"forward",
"losses",
"callbacks",
"backbone",
2 changes: 1 addition & 1 deletion stable_pretraining/data/transforms.py
@@ -210,7 +210,7 @@ class PILGaussianBlur(Transform):

_NAMES = ["sigma_x", "sigma_y"]

def __init__(self, sigma=None, p=1, source: str = "image", target: str = "image"):
def __init__(self, sigma=None, p=1., source: str = "image", target: str = "image"):
"""Gaussian blur as a callable object.

Args:
2 changes: 1 addition & 1 deletion stable_pretraining/manager.py
@@ -629,4 +629,4 @@ def _configure_checkpointing(self) -> None:
if is_slurm_job:
logging.error(
"\t\t CRITICAL SLURM WARNING: This job will lose all progress if it is preempted or requeued. It is highly recommended to configure checkpointing."
)
)