diff --git a/benchmarks/imagenet1k/nnclr-resnet50.py b/benchmarks/imagenet1k/nnclr-resnet50.py
new file mode 100644
index 00000000..1b1f949c
--- /dev/null
+++ b/benchmarks/imagenet1k/nnclr-resnet50.py
@@ -0,0 +1,191 @@
+"""NNCLR pre-training benchmark: ResNet-50 backbone on ImageNet-1k.
+
+Trains with LARS under 16-bit mixed precision and attaches an online linear
+probe, an online k-NN probe, and the support-set queue callback.
+"""
+
+import lightning as pl
+import torch
+import torchmetrics
+from lightning.pytorch.loggers import WandbLogger
+from torch import nn
+
+import stable_pretraining as spt
+from stable_pretraining.data import transforms
+from stable_pretraining.forward import nnclr_forward
+from stable_pretraining.callbacks.queue import OnlineQueue
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from utils import get_data_dir
+
+nnclr_transform = transforms.MultiViewTransform(
+    [
+        transforms.Compose(
+            transforms.RGB(),
+            transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0)),
+            transforms.ColorJitter(
+                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
+            ),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.PILGaussianBlur(p=1.0),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.ToImage(**spt.data.static.ImageNet),
+        ),
+        transforms.Compose(
+            transforms.RGB(),
+            transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0)),
+            transforms.ColorJitter(
+                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
+            ),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.PILGaussianBlur(p=0.1),
+            transforms.RandomSolarize(threshold=0.5, p=0.2),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.ToImage(**spt.data.static.ImageNet),
+        ),
+    ]
+)
+
+val_transform = transforms.Compose(
+    transforms.RGB(),
+    transforms.Resize((256, 256)),
+    transforms.CenterCrop((224, 224)),
+    transforms.ToImage(**spt.data.static.ImageNet),
+)
+
+train_dataset = spt.data.HFDataset(
+    "imagenet-1k",
+    split="train",
+    cache_dir=str(get_data_dir()),
+    transform=nnclr_transform,
+)
+val_dataset = spt.data.HFDataset(
+    "imagenet-1k",
+    split="validation",
+    cache_dir=str(get_data_dir()),
+    transform=val_transform,
+)
+
+total_batch_size, world_size, num_epochs = 4096, 4, 400
+local_batch_size = total_batch_size // world_size
+train_dataloader = torch.utils.data.DataLoader(
+    dataset=train_dataset,
+    sampler=spt.data.sampler.RepeatedRandomSampler(train_dataset, n_views=2),
+    batch_size=local_batch_size,
+    num_workers=64,
+    drop_last=True,
+    persistent_workers=True,
+)
+val_dataloader = torch.utils.data.DataLoader(
+    dataset=val_dataset,
+    batch_size=local_batch_size,
+    num_workers=32,
+    persistent_workers=True,
+)
+
+data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)
+
+backbone = spt.backbone.from_torchvision(
+    "resnet50",
+    low_resolution=False,
+)
+backbone.fc = torch.nn.Identity()
+
+projector = nn.Sequential(
+    nn.Linear(2048, 2048),
+    nn.BatchNorm1d(2048),
+    nn.ReLU(inplace=True),
+    nn.Linear(2048, 2048),
+    nn.BatchNorm1d(2048),
+    nn.ReLU(inplace=True),
+    nn.Linear(2048, 256),
+)
+
+predictor = nn.Sequential(
+    nn.Linear(256, 4096),
+    nn.BatchNorm1d(4096),
+    nn.ReLU(inplace=True),
+    nn.Linear(4096, 256),
+)
+
+
+module = spt.Module(
+    backbone=backbone,
+    projector=projector,
+    predictor=predictor,
+    forward=nnclr_forward,
+    nnclr_loss=spt.losses.NTXEntLoss(temperature=0.1),
+    optim={
+        "optimizer": {
+            "type": "LARS",
+            "lr": 0.3 * (total_batch_size / 256),
+            "weight_decay": 1e-6,
+            "clip_lr": True,
+            "eta": 0.001,
+            "exclude_bias_n_norm": True,
+        },
+        "scheduler": {
+            "type": "LinearWarmupCosineAnnealing",
+            "peak_step": 10 / num_epochs,
+            "total_steps": num_epochs * (len(train_dataloader) // world_size),
+        },
+        "interval": "epoch",
+    },
+    hparams={
+        "support_set_size": 98304,  # queue_length of the support-set OnlineQueue
+        "projection_dim": 256,  # matches the projector's output width
+    },
+)
+
+linear_probe = spt.callbacks.OnlineProbe(
+    name="linear_probe",
+    input="embedding",
+    target="label",
+    probe=torch.nn.Linear(2048, 1000),
+    loss_fn=torch.nn.CrossEntropyLoss(),
+    metrics={
+        "top1": torchmetrics.classification.MulticlassAccuracy(1000),
+        "top5": torchmetrics.classification.MulticlassAccuracy(1000, top_k=5),
+    },
+)
+
+knn_probe = spt.callbacks.OnlineKNN(
+    name="knn_probe",
+    input="embedding",
+    target="label",
+    queue_length=20000,
+    metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(1000)},
+    input_dim=2048,
+    k=20,
+)
+
+support_queue = OnlineQueue(
+    key="nnclr_support_set",
+    queue_length=module.hparams.support_set_size,
+    dim=module.hparams.projection_dim,
+)
+
+wandb_logger = WandbLogger(
+    entity="samibg",
+    project="imagenet-1k-nnclr",
+    name=f"nnclr-resnet50-{world_size}gpus",
+    log_model=False,
+)
+
+# --- Trainer ---
+trainer = pl.Trainer(
+    max_epochs=num_epochs,
+    num_sanity_val_steps=0,
+    callbacks=[linear_probe, knn_probe, support_queue],
+    precision="16-mixed",
+    logger=wandb_logger,
+    enable_checkpointing=True,
+    accelerator="gpu",
+    devices=world_size,
+    sync_batchnorm=world_size > 1,
+)
+
+manager = spt.Manager(trainer=trainer, module=module, data=data)
+manager()
diff --git a/benchmarks/imagenet1k/simclr-resnet50.py b/benchmarks/imagenet1k/simclr-resnet50.py
new file mode 100644
index 00000000..9568fb5c
--- /dev/null
+++ b/benchmarks/imagenet1k/simclr-resnet50.py
@@ -0,0 +1,176 @@
+"""SimCLR pre-training benchmark: ResNet-50 backbone on ImageNet-1k.
+
+Trains with LARS under 16-bit mixed precision and attaches an online linear
+probe and an online k-NN probe.
+"""
+
+import lightning as pl
+import torch
+import torchmetrics
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from torch import nn
+
+import stable_pretraining as spt
+from stable_pretraining.data import transforms
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+from utils import get_data_dir
+
+# SimCLR augmentations, as described in the paper
+simclr_transform = transforms.MultiViewTransform(
+    [
+        transforms.Compose(
+            transforms.RGB(),
+            transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), ratio=(3/4, 4/3)),
+            transforms.ColorJitter(
+                brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
+            ),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.PILGaussianBlur(sigma=(0.1, 2.0), p=0.5),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.ToImage(**spt.data.static.ImageNet),
+        ),
+        transforms.Compose(
+            transforms.RGB(),
+            transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), ratio=(3/4, 4/3)),
+            transforms.ColorJitter(
+                brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8
+            ),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.PILGaussianBlur(sigma=(0.1, 2.0), p=0.5),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.ToImage(**spt.data.static.ImageNet),
+        ),
+    ]
+)
+
+val_transform = transforms.Compose(
+    transforms.RGB(),
+    transforms.Resize((256, 256)),
+    transforms.CenterCrop((224, 224)),
+    transforms.ToImage(**spt.data.static.ImageNet),
+)
+
+# Using a standard ImageNet-1k dataset
+train_dataset = spt.data.HFDataset(
+    "imagenet-1k",
+    split="train",
+    cache_dir=str(get_data_dir()),
+    transform=simclr_transform,
+)
+val_dataset = spt.data.HFDataset(
+    "imagenet-1k",
+    split="validation",
+    cache_dir=str(get_data_dir()),
+    transform=val_transform,
+)
+
+# Batch size from the paper (adjust if necessary)
+total_batch_size, world_size, num_epochs = 4096, 8, 800
+local_batch_size = total_batch_size // world_size
+
+train_dataloader = torch.utils.data.DataLoader(
+    dataset=train_dataset,
+    batch_size=local_batch_size,
+    num_workers=64,
+    drop_last=True,
+    persistent_workers=True,
+)
+val_dataloader = torch.utils.data.DataLoader(
+    dataset=val_dataset,
+    batch_size=local_batch_size,
+    num_workers=32,
+    persistent_workers=True,
+)
+
+data = spt.data.DataModule(train=train_dataloader, val=val_dataloader)
+
+# Using ResNet-50 as in the paper
+backbone = spt.backbone.from_torchvision(
+    "resnet50",
+    low_resolution=False,
+)
+backbone.fc = torch.nn.Identity()
+
+# Projector network as described in SimCLR paper
+projector = nn.Sequential(
+    nn.Linear(2048, 2048),
+    nn.ReLU(inplace=True),
+    nn.Linear(2048, 128),
+)
+
+module = spt.Module(
+    backbone=backbone,
+    projector=projector,
+    forward=spt.forward.simclr_forward,
+    # Temperature can be tuned, 0.1 is a common value
+    simclr_loss=spt.losses.NTXEntLoss(temperature=0.1),
+    optim={
+        "optimizer": {
+            "type": "LARS",
+            # 256 is the base batch size used in the SimCLR paper
+            "lr": 0.3 * (total_batch_size / 256),
+            "weight_decay": 1e-6,
+            "clip_lr": True,
+            "eta": 1e-3,
+            "exclude_bias_n_norm": True,
+        },
+        "scheduler": {
+            "type": "LinearWarmupCosineAnnealing",
+            "peak_step": 10 / num_epochs,  # 10 epochs warmup
+            "total_steps": num_epochs * (len(train_dataloader) // world_size),
+        },
+        "interval": "epoch",
+    },
+)
+
+# Online probes sized for the 1000 ImageNet-1k classes
+linear_probe = spt.callbacks.OnlineProbe(
+    name="linear_probe",
+    input="embedding",
+    target="label",
+    probe=torch.nn.Linear(2048, 1000),
+    loss_fn=torch.nn.CrossEntropyLoss(),
+    metrics={
+        "top1": torchmetrics.classification.MulticlassAccuracy(1000),
+        "top5": torchmetrics.classification.MulticlassAccuracy(1000, top_k=5),
+    },
+)
+
+knn_probe = spt.callbacks.OnlineKNN(
+    name="knn_probe",
+    input="embedding",
+    target="label",
+    queue_length=20000,
+    metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(1000)},
+    input_dim=2048,
+    k=20,
+)
+
+wandb_logger = WandbLogger(
+    entity="samibg",
+    project="imagenet1k-simclr",
+    name=f"simclr-resnet50-{world_size}gpus",
+    log_model=False,
+)
+
+trainer = pl.Trainer(
+    max_epochs=num_epochs,
+    num_sanity_val_steps=0,
+    callbacks=[
+        linear_probe,
+        knn_probe,
+    ],
+    precision="16-mixed",
+    logger=wandb_logger,
+    enable_checkpointing=True,
+    accelerator="gpu",
+    devices=world_size,  # local_batch_size and total_steps assume this many GPUs
+    sync_batchnorm=world_size > 1,
+)
+
+manager = spt.Manager(trainer=trainer, module=module, data=data)
+manager()
diff --git a/stable_pretraining/__init__.py b/stable_pretraining/__init__.py
index ca30b716..7ed64f80 100644
--- a/stable_pretraining/__init__.py
+++ b/stable_pretraining/__init__.py
@@ -25,7 +25,7 @@ except ImportError:
     WANDB_AVAILABLE = False
 
 
-from . import backbone, callbacks, data, losses, module, optim, static, utils
+from . import backbone, callbacks, data, forward, losses, module, optim, static, utils
 from .__about__ import (
     __author__,
     __license__,
@@ -81,6 +81,7 @@
     "module",
     "static",
     "optim",
+    "forward",
     "losses",
     "callbacks",
     "backbone",
diff --git a/stable_pretraining/data/transforms.py b/stable_pretraining/data/transforms.py
index d8635f6f..6f0658f7 100644
--- a/stable_pretraining/data/transforms.py
+++ b/stable_pretraining/data/transforms.py
@@ -210,7 +210,7 @@ class PILGaussianBlur(Transform):
 
     _NAMES = ["sigma_x", "sigma_y"]
 
-    def __init__(self, sigma=None, p=1, source: str = "image", target: str = "image"):
+    def __init__(self, sigma=None, p=1., source: str = "image", target: str = "image"):
         """Gaussian blur as a callable object.
 
         Args: