diff --git a/assets/benchmarks/cifar10/byol.yaml b/assets/benchmarks/cifar10/byol.yaml index 7a3c227b..583c6fda 100644 --- a/assets/benchmarks/cifar10/byol.yaml +++ b/assets/benchmarks/cifar10/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/cifar10/dino.yaml b/assets/benchmarks/cifar10/dino.yaml index c2e3ec8e..d7b89869 100644 --- a/assets/benchmarks/cifar10/dino.yaml +++ b/assets/benchmarks/cifar10/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/assets/benchmarks/cifar100/byol.yaml b/assets/benchmarks/cifar100/byol.yaml index d727d3b1..c601f8d4 100644 --- a/assets/benchmarks/cifar100/byol.yaml +++ b/assets/benchmarks/cifar100/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/cifar100/dino.yaml b/assets/benchmarks/cifar100/dino.yaml index c2e3ec8e..d7b89869 100644 --- a/assets/benchmarks/cifar100/dino.yaml +++ b/assets/benchmarks/cifar100/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/assets/benchmarks/imagenette/byol.yaml b/assets/benchmarks/imagenette/byol.yaml index ce045e8b..fd3f8d5a 100644 --- a/assets/benchmarks/imagenette/byol.yaml +++ b/assets/benchmarks/imagenette/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: False num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/imagenette/dino.yaml 
b/assets/benchmarks/imagenette/dino.yaml index 7d0afbca..724240fa 100644 --- a/assets/benchmarks/imagenette/dino.yaml +++ b/assets/benchmarks/imagenette/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: False num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/benchmarks/cifar10/vicreg-resnet18.py b/benchmarks/cifar10/vicreg-resnet18.py index 85fe5e7b..ffc7d3dc 100644 --- a/benchmarks/cifar10/vicreg-resnet18.py +++ b/benchmarks/cifar10/vicreg-resnet18.py @@ -153,10 +153,11 @@ def forward(self, batch, stage): ) wandb_logger = WandbLogger( - entity="stable-ssl", - project="cifar10-vicreg", - name="vicreg-resnet18", - log_model=False, + project="ijepa-cifar10", + entity="samibg", # Your W&B entity + name="vicreg-cifar10-run", + log_model=False, # Set to True if you want to save model artifacts + offline=False, # Ensure offline mode ) trainer = pl.Trainer( @@ -165,6 +166,7 @@ def forward(self, batch, stage): callbacks=[knn_probe, linear_probe], precision="16-mixed", logger=wandb_logger, + devices=1, enable_checkpointing=False, ) diff --git a/benchmarks/imagenet100/ijepa-vith.py b/benchmarks/imagenet100/ijepa-vith.py new file mode 100644 index 00000000..e7127063 --- /dev/null +++ b/benchmarks/imagenet100/ijepa-vith.py @@ -0,0 +1,381 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Block +import lightning.pytorch.loggers as pl_loggers +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor + +import torchmetrics +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback +from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.data import transforms + + +import stable_ssl as ssl +import lightning as pl +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +NUM_DEVICES = 1 + +# Match paper's setup exactly +EFFECTIVE_BATCH_SIZE = 2048 +BATCH_SIZE = EFFECTIVE_BATCH_SIZE // NUM_DEVICES + +EFFECTIVE_VAL_BATCH_SIZE = 16384 +VAL_BATCH_SIZE = EFFECTIVE_VAL_BATCH_SIZE // NUM_DEVICES + +# Use paper's learning rates directly +effective_lr = 0.001 +effective_probe_lr = 0.005 + + +def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + +def repeat_interleave_batch(x: torch.Tensor, B, repeat): + N = x.shape[0] // B + x = x[:N*B] + return x.reshape(N, B, *x.shape[1:]) \ + .repeat_interleave(repeat, dim=1) \ + .reshape(N * B * repeat, *x.shape[1:]) + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 
0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def fix_init_weight(blocks): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + +class IJEPA_ViT_Encoder(nn.Module): + def __init__( + self, + img_size=224, patch_size=14, embed_dim=768, + depth=12, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.norm_layer = nn.LayerNorm + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, + in_chans=3, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + # -- + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**.5), + cls_token=False) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + # -- + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, + mlp_ratio=4.0, qkv_bias=True, attn_drop=0.0, drop_path=0.0, + norm_layer=self.norm_layer + ) + for _ in range(depth) + ]) + self.norm = self.norm_layer(embed_dim) + self.apply(init_weights) + fix_init_weight(self.blocks) + + def forward(self, x, masks: None | list = None): + x = self.patch_embed(x) + x = x + self.pos_embed + + if masks is not None: + x = apply_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + return self.norm(x) + +class IJEPA_ViT_Predictor(nn.Module): + def __init__( + self, + num_patches, embed_dim=768, predictor_embed_dim=384, + depth=6, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + # -- + self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + int(num_patches**.5), + cls_token=False) + self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + # -- + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, num_heads=num_heads, + mlp_ratio=4, qkv_bias=True, + attn_drop=0.0, drop_path=0.0, + norm_layer=nn.LayerNorm + ) + for _ in range(depth) + ]) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + # ------ + trunc_normal_(self.mask_token, std=0.02) + self.apply(init_weights) + fix_init_weight(self.predictor_blocks) + + def forward(self, context_patches, masks_context: list, masks_target: list): + assert len(masks_context) == 1 + B = len(context_patches) // len(masks_context) + x = self.predictor_embed(context_patches) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(x_pos_embed, masks_context) + + _, N_ctxt, D = x.shape + + # -- concat mask tokens to x + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_target) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_context)) + # -- + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + # -- + pred_tokens += pos_embs + x = x.repeat(len(masks_target), 1, 1) + x 
= torch.cat([x, pred_tokens], dim=1) + + # -- fwd prop + for blk in self.predictor_blocks: + x = blk(x) + x = self.predictor_norm(x) + + # -- return preds for mask tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + +def forward_ijepa(self: ssl.Module, batch: dict, stage: str) -> dict: + out = {} + + with torch.no_grad(): + # -- keys for rankme and linear probe + out['context_embeddings'] = self.context_encoder(batch['image'], masks = None) + out['meanpool'] = out['context_embeddings'].mean(dim=1) + out['flat'] = out['context_embeddings'].reshape(out['context_embeddings'].shape[0], -1) + + if not self.training: + return out + + target_patches = self.target_encoder(batch['image']) + target_patches = F.layer_norm(target_patches, (target_patches.size(-1),)) + target_patches = apply_masks(target_patches, batch['masks_target']) + out['target_embeddings'] = target_patches + + out['context_embeddings'] = self.context_encoder(batch['image'], [batch['mask_context']]) + out['predictions'] = self.predictor( + out['context_embeddings'], + [batch['mask_context']], + batch['masks_target'] + ) + out['loss'] = self.ijepa_loss(out['predictions'], out['target_embeddings']) + return out + + +inet1k_train = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="train", + transform=transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((224, 224), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask( + patch_size=14, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, + ), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +inet1k_val = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="validation", + transform=transforms.Compose( + transforms.RGB(), + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +# collate function that makes each mask have the same number of indices, so they can be batched +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=16, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=VAL_BATCH_SIZE, + num_workers=16, + shuffle=False, +) + +encoder = IJEPA_ViT_Encoder() +predictor = IJEPA_ViT_Predictor(num_patches=encoder.patch_embed.num_patches) +target_encoder = TeacherStudentWrapper(encoder, base_ema_coefficient=0.996, final_ema_coefficient=1.0) + +ema_callback = 
TeacherStudentCallback(update_frequency=1, update_after_backward=True) +rankme = ssl.callbacks.RankMe( + name="rankme", target="flat", + queue_length=max(512, BATCH_SIZE), + target_shape=(encoder.embed_dim * encoder.patch_embed.num_patches) +) + +linear_probe = ssl.callbacks.OnlineProbe( + name=f'linear_probe', input='meanpool', target='label', + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder.embed_dim, affine=False), + torch.nn.Linear(encoder.embed_dim, 100) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + # optimizer={ + # "type": "LARS", + # "lr": effective_probe_lr, + # "weight_decay": 1e-6, + # }, + # scheduler={ + # "type": "StepLR", + # "step_size": 15 * len(train), + # "gamma": 1.0, + # }, + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(100), + "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5), + } + ) + +knn_probe = ssl.callbacks.OnlineKNN( + name="knn_probe", + input="meanpool", + target="label", + queue_length=20000, + metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(100)}, + input_dim=768, + k=20, +) + + +module = ssl.Module( + context_encoder=encoder, + target_encoder=target_encoder, + predictor=predictor, + forward=forward_ijepa, + ijepa_loss=F.smooth_l1_loss, + optim={ + "optimizer": { + "type": "AdamW", + "lr": effective_lr, + "weight_decay": 0.04, # TODO Scheduler to 0.4 + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": 40 * int(len(train) / NUM_DEVICES), + "start_factor": 1 / 5, + "total_steps": 300 * int(len(train) / NUM_DEVICES), + "end_lr": 1.0e-6, + }, + "interval": "step", + } +) + +trainer = pl.Trainer( + max_epochs=300, num_sanity_val_steps=0, + callbacks=[ + linear_probe, knn_probe, rankme, ema_callback, + ModelCheckpoint(monitor='train/loss', mode='min', save_top_k=1, save_last=True, every_n_epochs=10), + LearningRateMonitor(logging_interval='step') + ], + precision='16-mixed', + logger=pl_loggers.WandbLogger( + project="ijepa-cifar10", entity="samibg", name=f"new-ijepa-inet100-num-devices{NUM_DEVICES}", + log_model=False, offline=False, + ), + # enable_checkpointing=False, + accelerator="gpu", devices=NUM_DEVICES, gradient_clip_val=None, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +data = ssl.data.DataModule(train=train, val=val) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) + +if __name__ == "__main__": + manager() \ No newline at end of file diff --git a/benchmarks/imagenet100/ijepa_vith.py b/benchmarks/imagenet100/ijepa_vith.py new file mode 100644 index 00000000..63db525e --- /dev/null +++ b/benchmarks/imagenet100/ijepa_vith.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Block +import lightning.pytorch.loggers as pl_loggers +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor + +import torchmetrics +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback +from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.data import transforms + +from stable_ssl.tests.scripts.masking import MaskCollator +import stable_ssl as ssl +import lightning as pl +from stable_ssl.utils.pos_embed import 
get_2d_sincos_pos_embed + +NUM_DEVICES = 4 +ZERO_LOSS = True + +# Match paper's setup exactly +EFFECTIVE_BATCH_SIZE = 2048 +BATCH_SIZE = EFFECTIVE_BATCH_SIZE // NUM_DEVICES + +EFFECTIVE_VAL_BATCH_SIZE = 16384 +VAL_BATCH_SIZE = EFFECTIVE_VAL_BATCH_SIZE // NUM_DEVICES + +# Use paper's learning rates directly +effective_lr = 0.001 +effective_probe_lr = 0.01 + + +def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + +def repeat_interleave_batch(x: torch.Tensor, B, repeat): + N = x.shape[0] // B + x = x[:N*B] + return x.reshape(N, B, *x.shape[1:]) \ + .repeat_interleave(repeat, dim=1) \ + .reshape(N * B * repeat, *x.shape[1:]) + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def fix_init_weight(blocks): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + +class IJEPA_ViT_Encoder(nn.Module): + def __init__( + self, + img_size=224, patch_size=14, embed_dim=768, + depth=12, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.norm_layer = nn.LayerNorm + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, + in_chans=3, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + # -- + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**.5), + cls_token=False) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + # -- + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, + mlp_ratio=4.0, qkv_bias=True, attn_drop=0.0, drop_path=0.0, + norm_layer=self.norm_layer + ) + for _ in range(depth) + ]) + self.norm = self.norm_layer(embed_dim) + self.apply(init_weights) + fix_init_weight(self.blocks) + + def forward(self, x, masks: None | list = None): + x = self.patch_embed(x) + x = x + self.pos_embed + + if masks is not None: + x = apply_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + return self.norm(x) + +class IJEPA_ViT_Predictor(nn.Module): + def __init__( + self, + num_patches, embed_dim=768, predictor_embed_dim=384, + depth=6, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + # -- + self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + int(num_patches**.5), + cls_token=False) + 
self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + # -- + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, num_heads=num_heads, + mlp_ratio=4, qkv_bias=True, + attn_drop=0.0, drop_path=0.0, + norm_layer=nn.LayerNorm + ) + for _ in range(depth) + ]) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + # ------ + trunc_normal_(self.mask_token, std=0.02) + self.apply(init_weights) + fix_init_weight(self.predictor_blocks) + + def forward(self, context_patches, masks_context: list, masks_target: list): + assert len(masks_context) == 1 + B = len(context_patches) // len(masks_context) + x = self.predictor_embed(context_patches) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(x_pos_embed, masks_context) + + _, N_ctxt, D = x.shape + + # -- concat mask tokens to x + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_target) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_context)) + # -- + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + # -- + pred_tokens += pos_embs + x = x.repeat(len(masks_target), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # -- fwd prop + for blk in self.predictor_blocks: + x = blk(x) + x = self.predictor_norm(x) + + # -- return preds for mask tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + +def forward_ijepa(self: ssl.Module, batch: dict, stage: str) -> dict: + out = {} + + with torch.no_grad(): + # -- keys for rankme and linear probe + out['context_embeddings'] = self.context_encoder(batch['image'], masks = None) + out['meanpool'] = out['context_embeddings'].mean(dim=1) + out['flat'] = out['context_embeddings'].reshape(out['context_embeddings'].shape[0], -1) + + if not self.training: + return out + + with torch.no_grad(): + target_patches = self.target_encoder(batch['image']) + target_patches = F.layer_norm(target_patches, (target_patches.size(-1),)) + target_patches = apply_masks(target_patches, batch['masks_target']) + out['target_embeddings'] = target_patches + + out['context_embeddings'] = self.context_encoder(batch['image'], [batch['mask_context']] if not isinstance(batch['mask_context'], list) else batch['mask_context']) + out['predictions'] = self.predictor( + out['context_embeddings'], + [batch['mask_context']] if not isinstance(batch['mask_context'], list) else batch['mask_context'], + batch['masks_target'] + ) + out['loss'] = self.ijepa_loss(out['predictions'], out['target_embeddings']) * int(not self.zero_loss) + return out + + +inet1k_train = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="train", + transform=transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((224, 224), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask( + patch_size=14, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, + ), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +inet1k_val = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="validation", + transform=transforms.Compose( + transforms.RGB(), + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +# collate function that makes 
each mask have the same number of indices, so they can be batched +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + +mask_transform_kwargs = dict( + input_size=(224, 224), + patch_size=14, + enc_mask_scale=(0.2, 0.8), + pred_mask_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5), + nenc=1, + npred=4, + min_keep=10, +) +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=64, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=VAL_BATCH_SIZE, + num_workers=64, + shuffle=False, +) + +encoder = IJEPA_ViT_Encoder() +predictor = IJEPA_ViT_Predictor(num_patches=encoder.patch_embed.num_patches) +target_encoder = TeacherStudentWrapper(encoder, base_ema_coefficient=0.996, final_ema_coefficient=1.0) + +ema_callback = TeacherStudentCallback(update_frequency=1, update_after_backward=True) +rankme = ssl.callbacks.RankMe( + name="rankme", target="flat", + queue_length=max(512, BATCH_SIZE), + target_shape=(encoder.embed_dim * encoder.patch_embed.num_patches) +) + +linear_probe = ssl.callbacks.OnlineProbe( + name=f'linear_probe', input='meanpool', target='label', + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder.embed_dim, affine=False), + torch.nn.Linear(encoder.embed_dim, 100) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + optimizer={ + "type": "LARS", + "lr": effective_probe_lr, + "weight_decay": 1e-6, + }, + scheduler={ + "type": "StepLR", + "step_size": 15 * len(train), + "gamma": 1.0, + }, + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(100), + "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5), + } + ) + +module = ssl.Module( + context_encoder=encoder, + target_encoder=target_encoder, + predictor=predictor, + zero_loss=ZERO_LOSS, + forward=forward_ijepa, + ijepa_loss=F.smooth_l1_loss, + optim={ + "optimizer": { + "type": "AdamW", + "lr": effective_lr, + "weight_decay": 0.04, # TODO Scheduler to 0.4 + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": 40 * int(len(train) / NUM_DEVICES), + "start_factor": 1 / 5, + "total_steps": 300 * int(len(train) / NUM_DEVICES), + "end_lr": 1.0e-6, + }, + "interval": "step", + } +) + +trainer = pl.Trainer( + max_epochs=300, num_sanity_val_steps=0, + limit_train_batches = None if not ZERO_LOSS else 1, + callbacks=[ + linear_probe, + rankme, ema_callback, + ModelCheckpoint(monitor='train/loss', mode='min', save_top_k=1, save_last=True, every_n_epochs=10), + LearningRateMonitor(logging_interval='step') + ], + precision='16-mixed', + logger=pl_loggers.WandbLogger( + project="ijepa-cifar10", entity="samibg", 
name=f"new-ijepa-inet100-num-devices{NUM_DEVICES}-zero-loss{ZERO_LOSS}", + log_model=False, offline=False, + ), + # enable_checkpointing=False, + accelerator="gpu", devices=NUM_DEVICES, gradient_clip_val=None, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +data = ssl.data.DataModule(train=train, val=val) + +manager = ssl.Manager(trainer=trainer, + module=module, + data=data, + ckpt_path='/home/sky/stable-SSL/ijepa-cifar10/yxvs6tll/checkpoints/epoch=119-step=7320.ckpt', +) + +if __name__ == "__main__": + manager() \ No newline at end of file diff --git a/benchmarks/imagenet100/vicreg-resnet50.py b/benchmarks/imagenet100/vicreg-resnet50.py index ca233e75..bee8f628 100644 --- a/benchmarks/imagenet100/vicreg-resnet50.py +++ b/benchmarks/imagenet100/vicreg-resnet50.py @@ -155,9 +155,9 @@ def forward(self, batch, stage): ) wandb_logger = WandbLogger( - entity="stable-ssl", - project="imagenet100-vicreg", - name="vicreg-resnet18", + entity="samibg", + project="ijepa-cifar10", + name="vicreg-resnet50", log_model=False, ) diff --git a/stable_ssl/data/masking.py b/stable_ssl/data/masking.py new file mode 100644 index 00000000..bc81e9fe --- /dev/null +++ b/stable_ssl/data/masking.py @@ -0,0 +1,330 @@ +import math + +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.colors import ListedColormap + + +def _sample_block_size( + height: int, + width: int, + min_scale: float, + max_scale: float, + min_aspect_ratio: float, + max_aspect_ratio: float, +): + """Sample a single block mask for an image. + + Args: + height (int): Height of the image in patches. + width (int): Width of the image in patches. + min_scale (float): Minimum scale factor for block area relative to total image area. + max_scale (float): Maximum scale factor for block area relative to total image area. + min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. + max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. + + Returns: + tuple[int, int]: A tuple (h, w) containing the sampled block height and width. + """ + _rand = torch.rand(1).item() + mask_scale = min_scale + _rand * (max_scale - min_scale) + max_keep = int(height * width * mask_scale) + aspect_ratio = min_aspect_ratio + _rand * (max_aspect_ratio - min_aspect_ratio) + + # Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + h = min(h, height - 1) + + w = int(round(math.sqrt(max_keep / aspect_ratio))) + w = min(w, width - 1) + + return (h, w) + + +def _sample_block_mask( + image_size: tuple[int, int], + block_size: tuple[int, int], + min_keep: int = 1, +): + """Sample a single block mask for an image. + Because mask positions are chosen randomly, we can occasionally end up with a mask that is too small. + This function will retry until a valid mask is found. + + Args: + image_size: Tuple[int, int] - Size of the image in patches + block_size: Tuple[int, int] - Size of the block in patches + min_keep (int): Minimum number of patches that must be in the block. 
+ + Returns: + torch.Tensor: A binary block mask of shape (height, width), where 1 marks patches inside the sampled block and 0 marks patches outside it. + """ + h, w = block_size + height, width = image_size + max_attempts = 20 + + for _ in range(max_attempts): + top = torch.randint(0, height - h + 1, (1,)).item() + left = torch.randint(0, width - w + 1, (1,)).item() + + mask = torch.zeros((height, width), dtype=torch.int32) + mask[top : top + h, left : left + w] = 1 + + # Return the first mask that satisfies min_keep. + if torch.sum(mask) >= min_keep: + return mask + + # If we run out of attempts, return whatever we had last. + else: + return mask + + +def multi_block_mask( + height: int, + width: int, + block_scales: list[tuple[float, float]] = [(0.85, 1.0), *((0.15, 0.2),) * 4], + aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), *((0.75, 1.5),) * 4], + min_keep: int = 1, +) -> list[torch.Tensor]: + # mapping from unique combinations of size x aspect ratio to the block size. + block_scale_to_size = { + (scale, ratio): _sample_block_size( + height, width, scale[0], scale[1], ratio[0], ratio[1] + ) + for scale, ratio in set(zip(block_scales, aspect_ratios)) + } + + masks: list[torch.Tensor] = [ + _sample_block_mask( + (height, width), block_scale_to_size[((sh, sw), (ah, aw))], min_keep + ) + for (sh, sw), (ah, aw) in zip(block_scales, aspect_ratios) + ] + # -- Return binary masks + return masks + + +def visualize_masking_strategy( + height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png" +): + """Visualize the I-JEPA masking strategy with multiple examples.
+ + Args: + height: Image height in patches + width: Image width in patches + num_examples: Number of masking examples to show + save_path: Path to save the visualization + """ + # Each example shows: context + 4 target blocks = 5 columns total + fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3 * num_examples)) + if num_examples == 1: + axes = axes.reshape(1, 5) + + # Set random seed for reproducible examples + torch.manual_seed(42) + + for i in range(num_examples): + # Generate masks - returns (context, combined_targets, individual_targets) + cleaned_context_mask, individual_target_masks = multi_block_mask( + height, + width, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5), + ) + + # Convert to numpy for visualization + context_np = cleaned_context_mask.numpy() + + # Column 0: Context mask only + ax = axes[i, 0] + cmap_context = ListedColormap(["white", "lightblue"]) + ax.imshow(context_np, cmap=cmap_context, vmin=0, vmax=1) + ax.set_title("Context" if i == 0 else "", fontsize=10) + ax.set_xticks(range(0, width, 4)) + ax.set_yticks(range(0, height, 4)) + ax.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax.axvline(x - 0.5, color="gray", linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax.axhline(y - 0.5, color="gray", linewidth=0.5, alpha=0.5) + + # Add row label + if i == 0: + ax.set_ylabel("Example 1", fontsize=12, rotation=0, ha="right", va="center") + else: + ax.set_ylabel( + f"Example {i + 1}", fontsize=12, rotation=0, ha="right", va="center" + ) + + # Columns 1-4: Individual target masks + target_colors = ["red", "orange", "green", "purple"] + for j, target_mask in enumerate(individual_target_masks): + ax = axes[i, j + 1] + target_np = target_mask.numpy() + + # Create colormap for this target + cmap_target = ListedColormap(["white", target_colors[j]]) + ax.imshow(target_np, cmap=cmap_target, vmin=0, vmax=1) + ax.set_title(f"Target {j + 1}" if i == 0 else "", fontsize=10) + ax.set_xticks(range(0, width, 4)) + ax.set_yticks(range(0, height, 4)) + ax.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax.axvline(x - 0.5, color="gray", linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax.axhline(y - 0.5, color="gray", linewidth=0.5, alpha=0.5) + + # Add legend + legend_elements = [ + patches.Patch(color="white", label="Visible patches"), + patches.Patch(color="lightblue", label="Context block"), + patches.Patch(color="red", label="Target block 1"), + patches.Patch(color="orange", label="Target block 2"), + patches.Patch(color="green", label="Target block 3"), + patches.Patch(color="purple", label="Target block 4"), + ] + + fig.legend( + handles=legend_elements, + loc="center", + bbox_to_anchor=(0.5, 0.02), + ncol=6, + fontsize=10, + ) + + plt.suptitle( + "I-JEPA Masking Strategy: Context and Individual Target Blocks", + fontsize=14, + y=0.95, + ) + plt.tight_layout() + plt.subplots_adjust(bottom=0.12, top=0.88) + + # Save the figure + plt.savefig(save_path, dpi=300, bbox_inches="tight") + print(f"Visualization saved to: {save_path}") + plt.close() + + +def analyze_masking_statistics(height=14, width=14, num_samples=1000): + """Analyze statistics of the masking strategy.""" + context_scales = [] + target_scales = [] + aspect_ratios = [] + + torch.manual_seed(42) + + for _ in range(num_samples): + cleaned_context_mask, individual_target_masks = multi_block_mask( + height, + width, + num_blocks=4, + context_scale=(0.85, 1.0), + 
target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5), + ) + + total_patches = height * width + context_scale = torch.sum(cleaned_context_mask).item() / total_patches + context_scales.append(context_scale) + + for target_mask in individual_target_masks: + target_scale = torch.sum(target_mask).item() / total_patches + target_scales.append(target_scale) + + # Calculate aspect ratio of target + coords = torch.where(target_mask == 1) + if len(coords[0]) > 0: + h_extent = coords[0].max() - coords[0].min() + 1 + w_extent = coords[1].max() - coords[1].min() + 1 + aspect_ratios.append(h_extent.item() / w_extent.item()) + + # Create statistics plot + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + axes[0].hist( + context_scales, bins=30, alpha=0.7, color="lightblue", edgecolor="black" + ) + axes[0].set_title("Context Block Scales") + axes[0].set_xlabel("Scale (fraction of image)") + axes[0].set_ylabel("Frequency") + axes[0].axvline( + np.mean(context_scales), + color="red", + linestyle="--", + label=f"Mean: {np.mean(context_scales):.3f}", + ) + axes[0].legend() + + axes[1].hist(target_scales, bins=30, alpha=0.7, color="orange", edgecolor="black") + axes[1].set_title("Target Block Scales") + axes[1].set_xlabel("Scale (fraction of image)") + axes[1].set_ylabel("Frequency") + axes[1].axvline( + np.mean(target_scales), + color="red", + linestyle="--", + label=f"Mean: {np.mean(target_scales):.3f}", + ) + axes[1].legend() + + axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color="green", edgecolor="black") + axes[2].set_title("Target Block Aspect Ratios") + axes[2].set_xlabel("Aspect Ratio (height/width)") + axes[2].set_ylabel("Frequency") + axes[2].axvline( + np.mean(aspect_ratios), + color="red", + linestyle="--", + label=f"Mean: {np.mean(aspect_ratios):.3f}", + ) + axes[2].legend() + + plt.tight_layout() + plt.savefig("ijepa_masking_statistics.png", dpi=300, bbox_inches="tight") + print("Statistics saved to: ijepa_masking_statistics.png") + plt.close() + + return { + "context_scales": context_scales, + "target_scales": target_scales, + "aspect_ratios": aspect_ratios, + } + + +if __name__ == "__main__": + print("Generating I-JEPA masking visualizations...") + + # Create main visualization + visualize_masking_strategy( + height=14, width=14, num_examples=8, save_path="ijepa_masking_examples.png" + ) + + # Generate statistics + stats = analyze_masking_statistics() + + print("\nMasking Statistics:") + print( + f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}" + ) + print( + f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}" + ) + print( + f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}" + ) + + print("\nVisualization complete! 
Check the generated PNG files.") diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index a91fbcdf..52588d59 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -15,6 +15,8 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import query_chw +from stable_ssl.data.masking import multi_block_mask + class Transform(v2.Transform): """Base transform class extending torchvision v2.Transform with nested data handling.""" @@ -711,6 +713,134 @@ def __call__(self, sample): return sample +class ContextTargetsMultiBlockMask(Transform): + """Transform that adds multi-block masks to a sample, with multiple target blocks and one disjoint context block. + + Args: + patch_size: Side length of a square patch, in pixels + context_scale: Range of the context block area as a fraction of the image + context_aspect_ratio: Aspect ratio range of the context block + target_scales: Area-fraction range for each target block + target_aspect_ratios: Aspect ratio range for each target block + min_keep: Minimum number of patches that must be in each block + source: Key of the input image in the sample + target_context: Key under which the context mask indices are stored + target_targets: Key under which the target mask indices are stored + + """ + + def __init__( + self, + patch_size=16, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, + source: str = "image", + target_context: str = "mask_context", + target_targets: str = "masks_target", + ): + super().__init__() + self.patch_size = patch_size + self.context_scale = context_scale + self.context_aspect_ratio = context_aspect_ratio + self.target_scales = target_scales + self.target_aspect_ratios = target_aspect_ratios + self.source = source + self.target_context = target_context + self.target_targets = target_targets + if len(target_scales) != len(target_aspect_ratios): + raise ValueError( + "Each scale must have its associated aspect ratio and vice versa. " + f"Received {len(target_scales)=} {len(target_aspect_ratios)=}" + ) + + self.min_keep = min_keep + + def __call__(self, x): + source = self.nested_get(x, self.source) + if isinstance(source, PIL.Image.Image): + W, H = source.size # PIL is W,H + elif isinstance(source, torch.Tensor): + # assumes H W + H, W = source.shape[-2:] + else: + raise ValueError( + f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead."
+ ) + + scales = [self.context_scale, *self.target_scales] + aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] + context_mask, *target_masks = multi_block_mask( + H // self.patch_size, + W // self.patch_size, + block_scales=scales, + aspect_ratios=aspect_ratios, + min_keep=self.min_keep, + ) + # makes targets disjoint with context + for mask in target_masks: + context_mask &= ~mask + + x[self.target_context] = torch.nonzero(context_mask.flatten()).squeeze() + x[self.target_targets] = [ + torch.nonzero(mask.flatten()).squeeze() for mask in target_masks + ] + x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) + return x + + +class RandomMask(Transform): + def __init__( + self, + patch_size=16, + mask_ratio=0.75, + source: str = "image", + target_visible: str = "mask_visible", + target_masked: str = "mask_masked", + target_ids_restore: str = "ids_restore", + target_len_keep: str = "len_keep", + ): + super().__init__() + self.patch_size = patch_size + self.mask_ratio = mask_ratio + self.source = source + self.target_visible = target_visible + self.target_masked = target_masked + self.target_ids_restore = target_ids_restore + self.target_len_keep = target_len_keep + + def __call__(self, x): + source = self.nested_get(x, self.source) + if isinstance(source, PIL.Image.Image): + W, H = source.size # PIL is W,H + elif isinstance(source, torch.Tensor): + # NOTE assumes _HW + H, W = source.shape[-2:] + else: + raise ValueError( + f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead." + ) + + num_patches = (H // self.patch_size) * (W // self.patch_size) + len_keep = int(num_patches * (1 - self.mask_ratio)) + + # Generate random noise and shuffle indices (like MAE) + noise = torch.rand(num_patches) + ids_shuffle = torch.argsort(noise) + ids_restore = torch.argsort(ids_shuffle) # inverse permutation + + # Split into visible and masked + mask_visible = ids_shuffle[:len_keep] # first len_keep are visible + mask_masked = ids_shuffle[len_keep:] # rest are masked + + # Add to sample + x[self.target_visible] = mask_visible + x[self.target_masked] = mask_masked + x[self.target_ids_restore] = ids_restore # NEW: for reconstructing full sequence + x[self.target_len_keep] = len_keep + + return x + + + class MultiViewTransform(v2.Transform): """Apply different transforms to different views of the same sample. diff --git a/stable_ssl/manager.py b/stable_ssl/manager.py index 8baa1c08..fb38b839 100644 --- a/stable_ssl/manager.py +++ b/stable_ssl/manager.py @@ -403,6 +403,16 @@ def __call__(self): else: ckpt_path = self.ckpt_path logging.info(f"📣📣📣 CALLING trainer.fit with {ckpt_path=} 📣📣📣") + # Enable passing in gradient_clip_* to pl.Trainer while using + # manual optimization of ssl.Module by re-assigning it here. 
+ # https://github.com/rbalestr-lab/stable-ssl/issues/246 + if hasattr(self.trainer, "gradient_clip_val"): + self.module.gradient_clip_val = self.trainer.gradient_clip_val + self.trainer.gradient_clip_val = None + if hasattr(self.trainer, "gradient_clip_algorithm"): + self.module.gradient_clip_algorithm = self.trainer.gradient_clip_algorithm + self.trainer.gradient_clip_algorithm = None + self._trainer.fit( self.instantiated_module, datamodule=self.instantiated_data, diff --git a/stable_ssl/module.py b/stable_ssl/module.py index cdd7165f..e9777b1f 100644 --- a/stable_ssl/module.py +++ b/stable_ssl/module.py @@ -116,6 +116,9 @@ def __init__(self, *args, forward: callable, hparams: dict = None, **kwargs): self._optimizer_index_by_name = None self._optimizer_frequencies = None + def on_fit_start(self): + super().on_fit_start() + def forward(self, *args, **kwargs): raise NotImplementedError("The forward() method must be implemented.") @@ -206,8 +209,8 @@ def training_step(self, batch, batch_idx): # Clip gradients for this optimizer then step self.clip_gradients( opt, - gradient_clip_val=self.trainer.gradient_clip_val, - gradient_clip_algorithm=self.trainer.gradient_clip_algorithm, + gradient_clip_val=self.gradient_clip_val, + gradient_clip_algorithm=self.gradient_clip_algorithm, ) opt.step() opt.zero_grad(set_to_none=True) diff --git a/stable_ssl/optim/lr_scheduler.py b/stable_ssl/optim/lr_scheduler.py index dabbcf50..13a1d7b5 100644 --- a/stable_ssl/optim/lr_scheduler.py +++ b/stable_ssl/optim/lr_scheduler.py @@ -395,3 +395,7 @@ def get_lr(self): / 2 for base_lr in self.base_lrs ] + + +if __name__ == "__main__": + _resolve_scheduler_callable("StepLR") \ No newline at end of file diff --git a/stable_ssl/tests/scripts/mae_arch_test.py b/stable_ssl/tests/scripts/mae_arch_test.py new file mode 100644 index 00000000..c7f9c812 --- /dev/null +++ b/stable_ssl/tests/scripts/mae_arch_test.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from timm.models.vision_transformer import VisionTransformer + +encoder_kwargs = dict( + img_size=32, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + # include cls token for positional embedding: + # https://github.com/facebookresearch/mae/blob/main/models_mae.py#L68-L69 + no_embed_class=False, + norm_layer=nn.LayerNorm, +) + +decoder_kwargs = dict( + img_size=32, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + class_token=True, + no_embed_class=False, + norm_layer=nn.LayerNorm, +) + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # -------------------------------------------------------------------------- + # MAE encoder specifics + # replicated from timm's + self.num_patches = self.patch_embed.num_patches + self.num_prefix_tokens + # TODO Exclude this and add posembeds from outside ? Can we do that ? 
I dont think so since we work on raw pixels + # self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim), requires_grad=False) # fixed sin-cos embedding + + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + patch_size = kwargs.get('patch_size', 16) + super().__init__(*args, **kwargs) + # -------------------------------------------------------------------------- + # MAE decoder specifics + # replicated from timm's + self.num_patches = self.patch_embed.num_patches + self.num_prefix_tokens + self.out_proj = nn.Linear(self.embed_dim, self.num_patches * patch_size**2) + +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + pass + + +def forward_encoder(encoder: MAE_Encoder, images: torch.Tensor, mask:torch.Tensor) -> torch.Tensor: + patches = encoder.patch_embed(images) + posemb = pos_embed(patches, with_cls = True) + posemb_patches = patches + posemb[:, 1:, :] + cls_token = encoder.cls_token + posemb[:, :1, :] + cls_tokens = cls_token.expand(patches.shape[0], -1, -1) + + masked_patches = apply_mask(posemb_patches, mask) + + patches = torch.cat([cls_tokens, masked_patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + + patches = encoder.norm(patches) + return patches + +def apply_mask(patches: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # TODO + pass + +def forward_decoder(decoder: MAE_Decoder, batch: dict) -> torch.Tensor: + patches = batch["patches"] + ids_restore = batch["ids_restore"] + + patches = decoder.patch_embed(patches) + + # NOTE from mae repo directly + # ids_restore im assuming can be the total set of indices + # dim1 is the number of tokens to predict. +1 for cls token? + # so we dont predict cls token? + mask_tokens = decoder.mask_token.repeat( + patches.shape[0], + ids_restore.shape[1] + 1 - patches.shape[1], + 1 + ) + # we want to pos-embed according to their original positions + # so we unshuffle first using ids_restore which is the inverse permutation + patches_ = torch.cat([patches[:, 1:, :], mask_tokens], dim=1) + patches_ = torch.gather(patches_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, patches.shape[2])) # unshuffle + # append cls token + patches = torch.cat([patches[:, :1, :], patches_], dim=1) + patches = patches + pos_embed(patches, with_cls = True) + # apply transformer blocks + patches = decoder.blocks(patches) + patches = decoder.norm(patches) + # predictor projection + patches = decoder.pred(patches) + # remove cls token + return patches[:, 1:, :] + +def forward_mae(self, batch: dict, stage): + out = {} + encoder = self.encoder + decoder = self.decoder + + if stage == "train": + pass \ No newline at end of file diff --git a/stable_ssl/tests/scripts/mae_arch_test2.py b/stable_ssl/tests/scripts/mae_arch_test2.py new file mode 100644 index 00000000..9380170c --- /dev/null +++ b/stable_ssl/tests/scripts/mae_arch_test2.py @@ -0,0 +1,207 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import VisionTransformer + +# -------------------- utils: positional embeddings -------------------- + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = True): + """Returns [1, grid_size*grid_size + cls, embed_dim] fixed sin-cos embeddings.""" + def get_1d_sin_cos(d, n): + omega = torch.arange(d // 2, dtype=torch.float32) / (d // 2) + omega = 1.0 / (10000 ** omega) # [d/2] + pos = torch.arange(n, dtype=torch.float32) # [n] + out = torch.einsum('n,d->nd', pos, omega) # 
[n, d/2] + return torch.cat([out.sin(), out.cos()], dim=1) # [n, d] + + assert embed_dim % 2 == 0 + # 2D grid + pe_h = get_1d_sin_cos(embed_dim // 2, grid_size) # [G, D/2] + pe_w = get_1d_sin_cos(embed_dim // 2, grid_size) # [G, D/2] + pe = ( + torch.stack([pe_h.unsqueeze(1).expand(-1, grid_size, -1), + pe_w.unsqueeze(0).expand(grid_size, -1, -1)], dim=2) + .reshape(grid_size * grid_size, embed_dim) + ) # [G*G, D] + if cls_token: + pe = torch.cat([torch.zeros(1, embed_dim), pe], dim=0) # [1+G*G, D] + return pe.unsqueeze(0) # [1, 1+G*G, D] + +def pos_embed(x: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + """ + x: [B, N(=H*W or tokens), D] + Returns fixed sin-cos PE of shape [B, (1+N), D] if with_cls else [B, N, D]. + """ + B, N, D = x.shape + G = int(math.sqrt(N)) # assumes square grid of patches + assert G * G == N, f"pos_embed expects square grid, got N={N}" + pe = get_2d_sincos_pos_embed(D, G, cls_token=with_cls).to(x.device, x.dtype) + return pe.expand(B, -1, -1) # [B, 1+N, D] or [B, N, D] depending on with_cls + +# -------------------- utils: masking and patchify -------------------- + +def apply_mask(x: torch.Tensor, ids_keep: torch.Tensor) -> torch.Tensor: + """ + Keep selection by indices. + x: [B, N, D], ids_keep: [B, K] + returns: [B, K, D] + """ + B, N, D = x.shape + idx = ids_keep.unsqueeze(-1).expand(-1, -1, D) # [B, K, D] + return x.gather(dim=1, index=idx) + + +# -------------------- encoder -------------------- + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches # do NOT add prefix here + + def patchify(self, images: torch.Tensor) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = self.patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + + +def forward_encoder(encoder: MAE_Encoder, images: torch.Tensor, ids_keep: torch.Tensor) -> torch.Tensor: + """ + Returns visible token latents (with cls prepended), projected by encoder blocks. 
+ """ + # Patch embed: [B, N, D] + x = encoder.patch_embed(images) + # Fixed PE + pe = pos_embed(x, with_cls=True) # [B, 1+N, D] + x = x + pe[:, 1:, :] # add PE to patch tokens + # cls token with PE + cls_tok = encoder.cls_token + pe[:, :1, :] # [B, 1, D] + cls_tok = cls_tok.expand(x.shape[0], -1, -1) + + # Keep only visible tokens + x_vis = apply_mask(x, ids_keep) # [B, K, D] + # prepend cls + x = torch.cat([cls_tok, x_vis], dim=1) # [B, 1+K, D] + + # Blocks + for blk in encoder.blocks: + x = blk(x) + x = encoder.norm(x) + return x # [B, 1+K, D] + +# -------------------- decoder -------------------- + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + self.enc_dim = kwargs.get('enc_dim', 768) + super().__init__(*args, **kwargs) + + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + + # Map encoder dim -> decoder dim (use embed_dim of this decoder) + # You must set this externally once you know the encoder dim: + self.decoder_embed = nn.Linear(self.enc_dim, self.embed_dim) # set to Linear(enc_dim, self.embed_dim) after init + + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + +def forward_decoder( + decoder: MAE_Decoder, + enc_tokens_vis: torch.Tensor, # [B, 1+K, D_enc] (cls + visible) + ids_restore: torch.Tensor # [B, N] +) -> torch.Tensor: + """ + Returns pixel predictions for all tokens (no cls): [B, N, P^2*C] + """ + B, _, D_enc = enc_tokens_vis.shape + + # project to decoder dim + x = decoder.decoder_embed(enc_tokens_vis) # [B, 1+K, D_dec] + + # prepare full grid by inserting mask tokens at missing positions + # 1) split cls vs visible + x_cls, x_vis = x[:, :1, :], x[:, 1:, :] # [B,1,D], [B,K,D] + # 2) number of tokens in grid + N = decoder.num_patches + K = x_vis.shape[1] + # 3) tokens to fill + mask_tokens = decoder.mask_token.expand(B, N - K, -1) # [B, N-K, D] + # 4) combine vis + mask, then unshuffle + x_ = torch.cat([x_vis, mask_tokens], dim=1) # [B, N, D] + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[-1])) # [B, N, D] + # 5) prepend cls + x = torch.cat([x_cls, x_], dim=1) # [B, 1+N, D] + + # add fixed PE for decoder + pe = pos_embed(x[:, 1:, :], with_cls=True) # reuse shape logic + x = x + pe # [B, 1+N, D] + + # transformer blocks + for blk in decoder.blocks: + x = blk(x) + x = decoder.norm(x) + + # predict pixels per token, drop cls + x = decoder.decoder_pred(x[:, 1:, :]) # [B, N, P^2*C] + return x + +# -------------------- MAE training forward -------------------- +mask_ratio = 0.75 + +def forward_mae(self, batch: dict, stage): + """ + Expects batch["image"]: [B,3,H,W] + Produces MAE reconstruction loss on masked patches. 
+ """ + out = {} + encoder: MAE_Encoder = self.encoder + decoder: MAE_Decoder = self.decoder + + images = batch["image"] # [B,3,H,W] + ids_keep = batch["mask_visible"] + ids_restore = batch["ids_restore"] + mask_masked = batch["mask_masked"] + + B = images.shape[0] + N = encoder.num_patches + + # 2) encode visible tokens (cls+visible) + enc_tokens_vis = forward_encoder(encoder, images, ids_keep) # [B,1+K,D_enc] + + # 4) decode to pixel predictions for ALL tokens (no cls) + pred_pix = forward_decoder(decoder, enc_tokens_vis, ids_restore) # [B,N,P^2*C] + + # 5) ground-truth patches + target_pix = encoder.patchify(images) # [B,N,P^2*C] + + # 6) compute loss ONLY on masked tokens + mask_exp = mask_masked.unsqueeze(-1).type_as(pred_pix) # [B,N,1] + loss = self.loss_fn(pred_pix, target_pix, mask_exp) + + # 7) populate outputs for logging / probes + if stage != "train": + return out + + out["loss"] = loss + out["mask_ratio"] = torch.tensor(mask_ratio, device=images.device) + out["ids_restore"] = ids_restore + out["mask"] = mask_masked + return out \ No newline at end of file diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py new file mode 100644 index 00000000..2a0f22eb --- /dev/null +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -0,0 +1,464 @@ +import math +from functools import partial +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +import stable_ssl as ssl +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback + + +train_batch_size = 512 +val_batch_size = 128 +num_epochs = 100 +lr_warmup_epochs = 15 +# max_grad_norm = 10.0 +max_grad_norm = None +ema = (0.97, 0.999) +ipe_scale = 1.25 +encoder_embed_dim = 64 +predictor_embed_dim = 32 +lr = 0.02 + +# -- data +num_workers = 64 +num_classes = 10 +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 32, 32, 2 +crop_height, crop_width = 32, 32 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +# we precompute these so the predictor can make sinusoidal posembeds +num_patches = (crop_height // patch_size) * (crop_width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size + +# based on the in1k_vith14_ep300.yaml config in the ijepa repository +mask_transform_kwargs = dict( + patch_size=patch_size, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=4, +) + + + +train_transform = transforms.Compose( + transforms.RGB(), + # transforms.CenterCrop((crop_height, crop_width)), + # transforms.RandomHorizontalFlip(), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) +# Don't mask during validation +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + transforms.CenterCrop((crop_height, crop_width)), + transforms.ToImage(mean=mean, std=std), +) + +# Use torchvision CIFAR-10 wrapped in FromTorchDataset +cifar_train = torchvision.datasets.CIFAR10( + 
root="/tmp/cifar10", train=True, download=True +) +cifar_val = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=False, download=True +) + + +class IndexedDataset(Dataset): + """Custom dataset wrapper that adds sample_idx to each sample.""" + + def __init__(self, dataset, transform=None): + super().__init__(transform) + self.dataset = dataset + + def __getitem__(self, idx): + image, label = self.dataset[idx] + sample = {"image": image, "label": label, "sample_idx": idx} + return self.process_sample(sample) + + def __len__(self): + return len(self.dataset) + + +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + + +train_dataset = IndexedDataset(cifar_train, transform=train_transform) +# IJEPA does not use multi-view sampling like SimCLR etc. because it processes +# single views and handles masking at the model level +train = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=train_batch_size, + shuffle=True, # Regular shuffling, no RepeatedRandomSampler + num_workers=num_workers, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val_dataset = IndexedDataset(cifar_val, transform=val_transform) +val = torch.utils.data.DataLoader( + dataset=val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True, + persistent_workers=True, +) + + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + + +def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + + +class IJEPA_Encoder(VisionTransformer): + """IJEPA encoder. + + Args: + ijepa_in_dim: Input dimension of the encoder, which is the patch dimension after re-arranging the image. + """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", True) + ijepa_in_dim = kwargs.pop("ijepa_in_dim") + super().__init__(*args, **kwargs) + + self.ijepa_patch_project = nn.Linear(ijepa_in_dim, self.embed_dim) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def patchify(self, image: torch.Tensor) -> torch.Tensor: + """Convert image tensor into patches. 
+ + Args: + image: Tensor of shape [B, C, H, W] + + Returns: + patches: Tensor of shape [B, N, P*P*C] where: + N = number of patches (H/P * W/P) + P = patch size + """ + B, C, H, W = image.shape + P = patch_size + + # Unfold into patches + patches = image.unfold(2, P, P).unfold(3, P, P) + + # Reshape to [B, num_patches, patch_dim] + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) + + return patches + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + # assume they are already reshaped to patches + return self.ijepa_patch_project(patches) + + def encode_patches( + self, patches: torch.Tensor, with_layernorm: bool = True + ) -> torch.Tensor: + x = self.blocks(patches) + if with_layernorm: + x = F.layer_norm(x, (x.size(-1),)) + return x + + def encode_image(self, image: torch.Tensor) -> torch.Tensor: + patches = self.patchify(image) + patches = self.project_patches(patches) + patches = patches + pos_embed(patches) + return self.encode_patches(patches) + + +class IJEPA_Predictor(VisionTransformer): + """IJEPA predictor, handles the logic of conditioning the predictor based on the context and target masks. + + Args: + predictor_num_patches: Number of patches in the predictor. This is typically equal to the number of patches in the context/target encoder. + ijepa_encoder_dim: Dimension of the IJEPA context/target encoder. This is used to up/down project the latents. + """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", True) + self.predictor_num_patches = kwargs.pop("predictor_num_patches") + self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") + self.predictor_pos_embed = pos_embed( + torch.zeros(1, self.predictor_num_patches, kwargs["embed_dim"]) + ) + super().__init__(*args, **kwargs) + self.predictor_pos_embed = nn.Parameter( + self.predictor_pos_embed, requires_grad=False + ) + self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) + self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) + self.predictor_norm = nn.LayerNorm(self.embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def project_context(self, context_patches: torch.Tensor) -> torch.Tensor: + return self.predictor_inproj(context_patches) + + def predict_targets( + self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] + ) -> torch.Tensor: + B, *_ = context_patches.shape + M = len(masks_target) + + # NOTE: These are already projected -> posembedded + ctx: torch.Tensor = context_patches + + # target position embeddings (stacked per mask): [B*M, K_tgt, D] + pos_all = self.predictor_pos_embed.expand(B, -1, -1) + tgt_pos = apply_masks(pos_all, *masks_target) + + # repeat context across M target blocks: [B*M, N_ctx, D]. 
this means that + # the predictor predicts each target block independently, and not their union, + # as the ijepa repo does + ctx = ctx.repeat_interleave(M, dim=0) + + # mask tokens placed at target positions + N_tgt = tgt_pos.size(1) + pred_tokens = self.mask_token.expand(B * M, N_tgt, -1) + tgt_pos + + # each target block now gets predicted: [B*M, N_ctx+N_tgt, D] + x = torch.cat([ctx, pred_tokens], dim=1) + x = self.blocks(x) + x = self.predictor_norm(x) + + pred = x[:, -N_tgt:] + pred = self.predictor_outproj(pred) + return pred + + +# pico vit +encoder_kwargs = dict( + patch_size=patch_size, + embed_dim=encoder_embed_dim, + depth=12, + num_heads=2, + qkv_bias=True, + ijepa_in_dim=patch_channel_dim, +) +predictor_kwargs = dict( + patch_size=patch_size, + embed_dim=predictor_embed_dim, + depth=6, + num_heads=2, + qkv_bias=True, + ijepa_encoder_dim=encoder_embed_dim, + predictor_num_patches=num_patches, +) + +context_encoder = IJEPA_Encoder(**encoder_kwargs) +predictor = IJEPA_Predictor(**predictor_kwargs) + + +def forward(self: ssl.Module, batch, stage): + out = {} + target_encoder: IJEPA_Encoder = self.target_encoder.teacher + context_encoder: IJEPA_Encoder = self.context_encoder + predictor: IJEPA_Predictor = self.predictor + ijepa_loss: nn.Module = self.ijepa_loss + with torch.no_grad(): + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["target_embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) + unmasked_context_patches = context_encoder.project_patches(image_patches) + unmasked_pos_embedding = pos_embed(unmasked_context_patches) + out["context_embedding"] = context_encoder.encode_patches( + unmasked_context_patches + unmasked_pos_embedding, + with_layernorm = False + ) + out["meanpool_context_embedding"] = out["context_embedding"].mean(dim=1) + out["sum_context_embedding"] = out["context_embedding"].sum(dim=1) + out["flat_context_embedding"] = out["context_embedding"].reshape(out["context_embedding"].shape[0], -1) + + if not self.training: + return out + + mask_context, masks_target = batch["mask_context"], batch["masks_target"] + # target encoder is applied on full patches, then masked + out["target_patches"] = apply_masks(out["target_embedding"], *masks_target) + # context encoder is applied on masked patches + context_patches = apply_masks(image_patches, mask_context) + context_patches = context_encoder.project_patches(context_patches) + context_patches = context_patches + apply_masks(pos_embedding, mask_context) + context_patches = context_encoder.encode_patches( + context_patches, with_layernorm=False + ) + out["context_patches"] = context_patches + # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output + out["predicted_patches"] = predictor.predict_targets( + predictor.project_context(context_patches) + apply_masks(predictor.predictor_pos_embed.repeat(context_patches.shape[0],1,1), mask_context), + masks_target + ) + out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) + # context embedding + return out + + +module = ssl.Module( + context_encoder=context_encoder, + target_encoder=TeacherStudentWrapper(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), + predictor=predictor, + forward=forward, + ijepa_loss=F.smooth_l1_loss, + # ijepa_loss=partial(F.mse_loss, 
reduction='mean'), + optim={ + "optimizer": { + "type": "AdamW", + "lr": lr, + "weight_decay": 0.0, + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": lr_warmup_epochs * len(train), + "total_steps": num_epochs * len(train), + }, + "interval": "step", + }, +) + + +probe_optimizer = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +probe_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs * len(train))) + +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "meanpool_context_embedding", + "label", + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder_embed_dim, affine=False), + torch.nn.Linear(encoder_embed_dim, 10) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + optimizer=probe_optimizer, + scheduler=probe_scheduler, + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(10), + "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + }, +) + +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_context_embedding", + queue_length=min(512, train_batch_size), # NOTE must be >= batch_size + target_shape=(encoder_embed_dim * num_patches), +) + +# Initialize W&B logger with explicit settings +wandb_logger = WandbLogger( + project="ijepa-cifar10", + entity="samibg", # Your W&B entity + name="ijepa-cifar10-run", + log_model=False, # Set to True if you want to save model artifacts + offline=False, # Ensure offline mode +) + +class PerModuleGradLogger(pl.Callback): + def __init__(self, modules=("predictor", "context_encoder", "target_encoder"), norm_type=2): + self.modules = modules + self.norm_type = norm_type + + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + device = pl_module.device + if trainer.global_step % 1000 == 0: + for mod in self.modules: + group = [(n, p) for n, p in pl_module.named_parameters() if n.startswith(f"{mod}.")] + grads = [p.grad for _, p in group if p.grad is not None] + if not grads: + pl_module.log(f"grad/{mod}_norm", torch.tensor(0.0, device=device), on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", torch.tensor(0, device=device), on_step=True, logger=True, sync_dist=True) + continue + norms = torch.stack([g.detach().data.float().norm(self.norm_type).to(device) for g in grads]) + total_norm = torch.norm(norms, self.norm_type) + nonzero = (norms > 0).sum() + pl_module.log(f"grad/{mod}_norm", total_norm, on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", nonzero, on_step=True, logger=True, sync_dist=True) + +trainer = pl.Trainer( + max_epochs=num_epochs, + num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first + callbacks=[ + linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder")), + TeacherStudentCallback(update_frequency=100), + ], + precision="16-mixed", + logger=wandb_logger, + devices=1, + enable_checkpointing=False, + gradient_clip_val=max_grad_norm, + gradient_clip_algorithm="norm", +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py new file mode 100644 index 00000000..75dad612 --- /dev/null +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -0,0 +1,447 @@ +import torch +import math +import lightning as pl +import torch.nn.functional as F +import torchmetrics +from lightning.pytorch.loggers import WandbLogger +from 
timm.models.vision_transformer import VisionTransformer +from torch import nn +from functools import partial + +import stable_ssl as ssl +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.data import transforms +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback + +train_batch_size = 128 +val_batch_size = 128 +num_workers = 32 +num_classes = 1000 +num_epochs = 300 +lr_warmup_epochs = 60 +start_lr = 0.0002 +lr = 0.001 +final_lr = 1.0e-6 +# max_grad_norm = 5.0 +max_grad_norm = None +ema = (0.996, 1.0) + + +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 256, 256, 14 +crop_height, crop_width = 224, 224 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +# We precompute these so the predictor can make sinusoidal posembeds +num_patches = (crop_height // patch_size) * (crop_width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size +encoder_embed_dim = 768 +predictor_embed_dim = 384 + +# Based on the in1k_vith14_ep300.yaml config in the ijepa repository +mask_transform_kwargs = dict( + patch_size=patch_size, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, +) + + +train_transform = transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) + +# Don't mask during validation +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + transforms.CenterCrop((crop_height, crop_width)), + transforms.ToImage(mean=mean, std=std), +) + + +inet1k_train = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="train", + transform=train_transform, +) + +inet1k_val = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="validation", + transform=val_transform, +) + + +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + + +# IJEPA does not use multi-view sampling like SimCLR etc. 
because it processes +# single views and handles masking at the model level +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=train_batch_size, + shuffle=True, # Regular shuffling, no RepeatedRandomSampler + num_workers=num_workers, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, +) + + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + + +def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + + +class IJEPA_Encoder(VisionTransformer): + """IJEPA encoder. + + Args: + ijepa_in_dim: Input dimension of the encoder, which is the patch dimension after re-arranging the image. + """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", True) + ijepa_in_dim = kwargs.pop("ijepa_in_dim") + super().__init__(*args, **kwargs) + + self.ijepa_patch_project = nn.Linear(ijepa_in_dim, self.embed_dim) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def patchify(self, image: torch.Tensor) -> torch.Tensor: + """Convert image tensor into patches. + + Args: + image: Tensor of shape [B, C, H, W] + + Returns: + patches: Tensor of shape [B, N, P*P*C] where: + N = number of patches (H/P * W/P) + P = patch size + """ + B, C, H, W = image.shape + P = patch_size + + # Unfold into patches + patches = image.unfold(2, P, P).unfold(3, P, P) + + # Reshape to [B, num_patches, patch_dim] + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) + + return patches + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + # assume they are already reshaped to patches + return self.ijepa_patch_project(patches) + + def encode_patches( + self, patches: torch.Tensor, with_layernorm: bool = True + ) -> torch.Tensor: + x = self.blocks(patches) + if with_layernorm: + x = F.layer_norm(x, (x.size(-1),)) + return x + + def encode_image(self, image: torch.Tensor) -> torch.Tensor: + patches = self.patchify(image) + patches = self.project_patches(patches) + patches = patches + pos_embed(patches) + return self.encode_patches(patches) + + +class IJEPA_Predictor(VisionTransformer): + """IJEPA predictor, handles the logic of conditioning the predictor based on the context and target masks. + + Args: + predictor_num_patches: Number of patches in the predictor. This is typically equal to the number of patches in the context/target encoder. + ijepa_encoder_dim: Dimension of the IJEPA context/target encoder. This is used to up/down project the latents. 
+ """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", True) + self.predictor_num_patches = kwargs.pop("predictor_num_patches") + self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") + self.predictor_pos_embed = pos_embed( + torch.zeros(1, self.predictor_num_patches, kwargs["embed_dim"]) + ) + super().__init__(*args, **kwargs) + self.predictor_pos_embed = nn.Parameter( + self.predictor_pos_embed, requires_grad=False + ) + self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) + self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) + self.predictor_norm = nn.LayerNorm(self.embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def project_context(self, context_patches: torch.Tensor) -> torch.Tensor: + return self.predictor_inproj(context_patches) + + def predict_targets( + self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] + ) -> torch.Tensor: + # NOTE context_patches already positionally embedded + B, *_ = context_patches.shape + M = len(masks_target) + + # NOTE: These are already projected -> posembedded + ctx: torch.Tensor = context_patches + + # target position embeddings (stacked per mask): [B*M, K_tgt, D] + pos_all = self.predictor_pos_embed.expand(B, -1, -1) + tgt_pos = apply_masks(pos_all, *masks_target) + + # repeat context across M target blocks: [B*M, N_ctx, D]. this means that + # the predictor predicts each target block independently, and not their union, + # as the ijepa repo does + ctx = ctx.repeat_interleave(M, dim=0) + + # mask tokens placed at target positions + N_tgt = tgt_pos.size(1) + pred_tokens = self.mask_token.expand(B * M, N_tgt, -1) + tgt_pos + + # each target block now gets predicted: [B*M, N_ctx+N_tgt, D] + x = torch.cat([ctx, pred_tokens], dim=1) + x = self.blocks(x) + x = self.predictor_norm(x) + + pred = x[:, -N_tgt:] + pred = self.predictor_outproj(pred) + return pred + + +encoder_kwargs = dict( + patch_size=patch_size, + embed_dim=encoder_embed_dim, + depth=12, + num_heads=12, + qkv_bias=True, + ijepa_in_dim=patch_channel_dim, +) +predictor_kwargs = dict( + patch_size=patch_size, + embed_dim=predictor_embed_dim, + depth=12, + num_heads=12, + qkv_bias=True, + ijepa_encoder_dim=encoder_embed_dim, + predictor_num_patches=num_patches, +) + + +def forward(self: ssl.Module, batch, stage): + out = {} + target_encoder: IJEPA_Encoder = self.target_encoder.teacher + context_encoder: IJEPA_Encoder = self.context_encoder + predictor: IJEPA_Predictor = self.predictor + ijepa_loss: nn.Module = self.ijepa_loss + with torch.no_grad(): + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["target_embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) + unmasked_context_patches = context_encoder.project_patches(image_patches) + unmasked_pos_embedding = pos_embed(unmasked_context_patches) + out["context_embedding"] = context_encoder.encode_patches( + unmasked_context_patches + unmasked_pos_embedding, + with_layernorm = False + ) + out["meanpool_context_embedding"] = out["context_embedding"].mean(dim=1) + out["sum_context_embedding"] = out["context_embedding"].sum(dim=1) + 
out["flat_context_embedding"] = out["context_embedding"].reshape(out["context_embedding"].shape[0], -1) + + if not self.training: + return out + + mask_context, masks_target = batch["mask_context"], batch["masks_target"] + # target encoder is applied on full patches, then masked + out["target_patches"] = apply_masks(out["target_embedding"], *masks_target) + # context encoder is applied on masked patches + context_patches = apply_masks(image_patches, mask_context) + context_patches = context_encoder.project_patches(context_patches) + context_patches = context_patches + apply_masks(pos_embedding, mask_context) + context_patches = context_encoder.encode_patches( + context_patches, with_layernorm=False + ) + out["context_patches"] = context_patches + # TODO Re-write this whole thing with the patchembeds inside the models themselves and using timms patchembed/block + # but making ur own vits. + # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output + predictor_pos_embed = apply_masks(predictor.predictor_pos_embed.repeat(context_patches.shape[0],1,1), mask_context) + out["predicted_patches"] = predictor.predict_targets( + predictor.project_context(context_patches) + predictor_pos_embed, + masks_target + ) + out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) + # context embedding + return out + + + +module = ssl.Module( + context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), + target_encoder=TeacherStudentWrapper(ctx), + predictor=IJEPA_Predictor(**predictor_kwargs), + forward=forward, + ijepa_loss=F.smooth_l1_loss, + optim={ + "optimizer": { + "type": "AdamW", + "lr": lr, + "weight_decay": 0.04, + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": lr_warmup_epochs * len(train), + "total_steps": num_epochs * len(train), + "end_lr": final_lr, + }, + "interval": "step", + } +) + + +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "meanpool_context_embedding", + "label", + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder_embed_dim, affine=False), + torch.nn.Linear(encoder_embed_dim, num_classes) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), + }, +) + +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_context_embedding", + queue_length=min(512, train_batch_size), # NOTE must be >= batch_size + target_shape=(encoder_embed_dim * num_patches), +) + +class PerModuleGradLogger(pl.Callback): + def __init__(self, modules=("predictor", "context_encoder", "target_encoder"), norm_type=2): + self.modules = modules + self.norm_type = norm_type + + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + device = pl_module.device + if trainer.global_step % 100 == 0: + for mod in self.modules: + group = [(n, p) for n, p in pl_module.named_parameters() if n.startswith(f"{mod}.")] + grads = [p.grad for _, p in group if p.grad is not None] + if not grads: + pl_module.log(f"grad/{mod}_norm", torch.tensor(0.0, device=device), on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", torch.tensor(0, device=device), on_step=True, logger=True, sync_dist=True) + continue + norms = torch.stack([g.detach().data.float().norm(self.norm_type).to(device) for g in grads]) + total_norm = torch.norm(norms, self.norm_type) + nonzero = (norms > 0).sum() + pl_module.log(f"grad/{mod}_norm", 
total_norm, on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", nonzero, on_step=True, logger=True, sync_dist=True) + + +# Initialize W&B logger with explicit settings +wandb_logger = WandbLogger( + project="ijepa-cifar10", + entity="samibg", # Your W&B entity + name="ijepa-inet1k-run", + log_model=False, # Set to True if you want to save model artifacts + offline=False, # Ensure offline mode +) + +trainer = pl.Trainer( + max_epochs=num_epochs, + num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first + callbacks=[linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder")), + TeacherStudentCallback(update_frequency=1, update_after_backward=True), + ], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, + accelerator="gpu", + devices=8, + gradient_clip_val=max_grad_norm, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py new file mode 100644 index 00000000..fe2fed28 --- /dev/null +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -0,0 +1,342 @@ +import math + +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +import stable_ssl as ssl +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + + +# Dataset configuration +num_classes = 10 +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] + +# Image and patch configuration +height, width = 32, 32 +crop_height, crop_width = 32, 32 +patch_size = 2 +num_patches = (crop_height // patch_size) * (crop_width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size + +# Masking configuration +mask_ratio = 0.75 +num_visible_patches = int(num_patches * (1 - mask_ratio)) + +# Training configuration +batch_size = 128 +val_batch_size = 128 +num_epochs = 2000 +num_workers = 16 + +# Optimization configuration +lr = 1.5e-4 +warmup_epochs = 400 +weight_decay = 0.05 + +# Model configuration +encoder_embed_dim = 192 +decoder_embed_dim = 192 + +mask_transform_kwargs = dict( + patch_size=patch_size, + mask_ratio=mask_ratio, + source="image", + target_visible="mask_visible", + target_masked="mask_masked", +) + +train_transform = transforms.Compose( + # transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic + transforms.RandomHorizontalFlip(), + transforms.RandomMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) + +val_transform = transforms.Compose( + transforms.RGB(), + # transforms.Resize((height, width)), + # transforms.CenterCrop((height, width)), + transforms.ToImage(mean=mean, std=std), +) + + +cifar_train = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=True, download=True +) +cifar_val = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=False, download=True +) + + +class IndexedDataset(Dataset): + """Custom dataset wrapper that adds sample_idx to each sample.""" + + def __init__(self, dataset, transform=None): + 
super().__init__(transform) + self.dataset = dataset + + def __getitem__(self, idx): + image, label = self.dataset[idx] + sample = {"image": image, "label": label, "sample_idx": idx} + return self.process_sample(sample) + + def __len__(self): + return len(self.dataset) + + + +train_dataset = IndexedDataset(cifar_train, transform=train_transform) +val_dataset = IndexedDataset(cifar_val, transform=val_transform) +train = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=True, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=True, +) + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + embed = ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + if with_cls: + embed = torch.cat([ + torch.zeros(embed.shape[0], 1, embed.shape[2], device=embed.device), + embed + ], dim=1) + + return embed + + +def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Apply single mask to tensor""" + B, N, D = x.shape + mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) + return torch.gather(x, dim=1, index=mask_expanded) + + +def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_in_dim = kwargs.pop('mae_in_dim', 3 * patch_size * patch_size) + super().__init__(*args, **kwargs) + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches + self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + return self.mae_patch_project(patches) + + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_enc_dim = kwargs.pop('mae_enc_dim', 768) + super().__init__(*args, **kwargs) + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + self.decoder_embed = nn.Linear(mae_enc_dim, self.embed_dim) + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + + +def _forward_decoder(self, batch: dict, out: dict, stage) -> torch.Tensor: + decoder: MAE_Decoder = self.decoder + patches_visible = out["embeddings"] + inverse_shuffle = batch["ids_restore"].unsqueeze(-1).expand(-1, -1, decoder.embed_dim) + # project encoded patches to decoder's space + decoder_patches = decoder.decoder_embed(patches_visible) + + patches_cls, patches_visible = torch.split(decoder_patches, [1, 
decoder_patches.shape[1] - 1], dim=1) + batch_size = patches_visible.shape[0] + num_patches = decoder.num_patches + num_visible = patches_visible.shape[1] + + mask_tokens = decoder.mask_token.expand(batch_size, num_patches - num_visible, -1) + # combine visible patches and mask tokens then unshuffle into place + patches = torch.cat([patches_visible, mask_tokens], dim=1) + unshuffled_patches = torch.gather(patches, dim=1, index=inverse_shuffle) + patches = torch.cat([patches_cls, unshuffled_patches], dim=1) + + pe = pos_embed(patches[:, 1:, :], with_cls=True) + patches = patches + pe + for blk in decoder.blocks: + patches = blk(patches) + patches = decoder.norm(patches) + # remove cls token + patches = patches[:, 1:, :] + patches = decoder.decoder_pred(patches) + return patches + + +def forward(self, batch: dict, stage): + out = {} + encoder: MAE_Encoder = self.encoder + images = batch["image"] + image_patches = patchify(images, patch_size) + patches = encoder.project_patches(image_patches) + posemb_cls, posemb_patches = torch.split(pos_embed(patches, with_cls=True), [1, patches.shape[1]], dim=1) + patches = patches + posemb_patches + cls_tok = encoder.cls_token + posemb_cls + + if self.training: + indices_keep, indices_masked = batch["mask_visible"], batch["mask_masked"] + patches_visible = apply_mask(patches, indices_keep) + patches_visible = torch.cat([cls_tok, patches_visible], dim=1) + for blk in encoder.blocks: + patches_visible = blk(patches_visible) + patches_visible = encoder.norm(patches_visible) + out["embeddings"] = patches_visible + out["reconstructed_pixel_patches"] = _forward_decoder(self, batch, out, stage) + out["loss"] = self.loss_fn( + apply_mask(out["reconstructed_pixel_patches"], indices_masked), + apply_mask(image_patches, indices_masked) + ) + else: + patches = torch.cat([cls_tok, patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + patches = encoder.norm(patches) + out["embeddings"] = patches + + return out + + +encoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + embed_dim=encoder_embed_dim, + depth=12, + num_heads=3, + qkv_bias=True, # MAE typically uses bias + mae_in_dim=patch_channel_dim, +) + +decoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + mae_enc_dim=encoder_embed_dim, + embed_dim=decoder_embed_dim, + depth=4, + num_heads=3, +) + + +module = ssl.Module( + encoder=MAE_Encoder(**encoder_kwargs), + decoder=MAE_Decoder(**decoder_kwargs), + forward=forward, + loss_fn=F.mse_loss, # pixel MSE loss. 
we make implicit assumption that norm-pix-loss is False + optim={ + "optimizer": { + "type": "AdamW", + "lr": lr * (batch_size / 256), + "weight_decay": weight_decay, + "betas": (0.9, 0.999), + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": warmup_epochs * len(train), + "total_steps": num_epochs * len(train), + }, + "interval": "step", + }, +) + + +class MAE_Classifier(torch.nn.Module): + def __init__(self, patch_dim: int, num_classes: int): + super().__init__() + # classifies from the cls token + self.classifier = torch.nn.Linear(patch_dim, num_classes) + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + cls_token = embedding[:, 0, :] + return self.classifier(cls_token) + + +# Note: Linear probe uses visible patches only during training +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "embeddings", + "label", + probe=MAE_Classifier(encoder_embed_dim, num_classes), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), + }, +) + + +# Initialize W&B logger +wandb_logger = WandbLogger( + project="ijepa-cifar10", + entity="samibg", + name="mae-cifar10-run", + log_model=False, + offline=False, +) + +trainer = pl.Trainer( + max_epochs=num_epochs, + num_sanity_val_steps=0, + callbacks=[linear_probe], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, + accelerator="gpu", + devices=1, +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() \ No newline at end of file diff --git a/stable_ssl/tests/scripts/train_mae_inet1k.py b/stable_ssl/tests/scripts/train_mae_inet1k.py new file mode 100644 index 00000000..474a66ba --- /dev/null +++ b/stable_ssl/tests/scripts/train_mae_inet1k.py @@ -0,0 +1,303 @@ +import math + +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn +from lightning.pytorch.strategies import DDPStrategy + +import stable_ssl as ssl +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +encoder_embed_dim = 768 +decoder_embed_dim = 512 +train_batch_size = 128 +val_batch_size = 128 +num_workers = 0 +num_classes = 1000 + +# TODO +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 256, 256, 16 +crop_height, crop_width = 224, 224 +num_patches = (crop_height // patch_size) * (crop_width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size +mask_ratio = 0.75 +num_visible_patches = int(num_patches * (1 - mask_ratio)) + +mask_transform_kwargs = dict( + patch_size=patch_size, + mask_ratio=mask_ratio, + source="image", + target_visible="mask_visible", + target_masked="mask_masked", +) + +train_transform = transforms.Compose( + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.RandomMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) + +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + transforms.CenterCrop((height, width)), + transforms.ToImage(mean=mean, std=std), +) + + +inet1k_train = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="train", + 
transform=train_transform, +) + +inet1k_val = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="validation", + transform=val_transform, +) + + +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=train_batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=True, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=num_workers > 0, # persistent workers require num_workers > 0 +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=num_workers > 0, # persistent workers require num_workers > 0 +) + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + embed = ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + if with_cls: + embed = torch.cat([ + torch.zeros(embed.shape[0], 1, embed.shape[2], device=embed.device), + embed + ], dim=1) + + return embed + + +def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Apply single mask to tensor""" + B, N, D = x.shape + mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) + return torch.gather(x, dim=1, index=mask_expanded) + + +def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_in_dim = kwargs.pop('mae_in_dim', 3 * patch_size * patch_size) + super().__init__(*args, **kwargs) + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches + self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + return self.mae_patch_project(patches) + + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_enc_dim = kwargs.pop('mae_enc_dim', 768) + super().__init__(*args, **kwargs) + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + self.decoder_embed = nn.Linear(mae_enc_dim, self.embed_dim) + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + + +def _forward_decoder(self, batch: dict, out: dict, stage) -> torch.Tensor: + decoder: MAE_Decoder = self.decoder + patches_visible = out["embeddings"] + inverse_shuffle = batch["ids_restore"].unsqueeze(-1).expand(-1, -1, decoder.embed_dim) + # project encoded patches to decoder's space + decoder_patches = decoder.decoder_embed(patches_visible) + + patches_cls, patches_visible = torch.split(decoder_patches, [1, decoder_patches.shape[1] - 1], dim=1) + batch_size = patches_visible.shape[0] + num_patches = decoder.num_patches + num_visible = patches_visible.shape[1] + + mask_tokens = decoder.mask_token.expand(batch_size, num_patches - num_visible, -1) + # combine visible patches and
mask tokens then unshuffle into place + patches = torch.cat([patches_visible, mask_tokens], dim=1) + unshuffled_patches = torch.gather(patches, dim=1, index=inverse_shuffle) + patches = torch.cat([patches_cls, unshuffled_patches], dim=1) + + pe = pos_embed(patches[:, 1:, :], with_cls=True) + patches = patches + pe + for blk in decoder.blocks: + patches = blk(patches) + patches = decoder.norm(patches) + # remove cls token + patches = patches[:, 1:, :] + patches = decoder.decoder_pred(patches) + return patches + + +def forward(self, batch: dict, stage): + out = {} + encoder: MAE_Encoder = self.encoder + images = batch["image"] + image_patches = patchify(images, patch_size) + patches = encoder.project_patches(image_patches) + posemb_cls, posemb_patches = torch.split(pos_embed(patches, with_cls=True), [1, patches.shape[1]], dim=1) + patches = patches + posemb_patches + cls_tok = encoder.cls_token + posemb_cls + + if self.training: + indices_keep, indices_masked = batch["mask_visible"], batch["mask_masked"] + patches_visible = apply_mask(patches, indices_keep) + patches_visible = torch.cat([cls_tok, patches_visible], dim=1) + for blk in encoder.blocks: + patches_visible = blk(patches_visible) + patches_visible = encoder.norm(patches_visible) + out["embeddings"] = patches_visible + out["reconstructed_pixel_patches"] = _forward_decoder(self, batch, out, stage) + out["loss"] = self.loss_fn( + apply_mask(out["reconstructed_pixel_patches"], indices_masked), + apply_mask(image_patches, indices_masked) + ) + else: + patches = torch.cat([cls_tok, patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + patches = encoder.norm(patches) + out["embeddings"] = patches + + # exclude cls token for rankme (flat_embedding) and linear probe (sum_embedding) + out["flat_embedding"] = out["embeddings"][:, 1:, :].flatten(start_dim=1) + out["sum_embedding"] = out["embeddings"][:, 1:, :].sum(dim=1) + return out + + +encoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + embed_dim=encoder_embed_dim, + depth=16, + num_heads=16, + qkv_bias=True, # MAE typically uses bias + mae_in_dim=patch_channel_dim, +) + +decoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + mae_enc_dim=encoder_embed_dim, + embed_dim=decoder_embed_dim, + depth=8, + num_heads=16, +) + + +module = ssl.Module( + encoder=MAE_Encoder(**encoder_kwargs), + decoder=MAE_Decoder(**decoder_kwargs), + forward=forward, + loss_fn=F.mse_loss, # pixel MSE loss. 
we make implicit assumption that norm-pix-loss is False +) + +# Note: Linear probe uses visible patches only during training +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "sum_embedding", + "label", + probe=torch.nn.Linear(encoder_embed_dim, num_classes), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), + }, +) + +# RankMe on encoder outputs +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_embedding", + queue_length=min(512, train_batch_size), + target_shape=(num_visible_patches, encoder_embed_dim), +) + +# Initialize W&B logger +wandb_logger = WandbLogger( + project="mae-inet1k", + entity="slightly-more-badass", + name="mae-inet1k-run", + log_model=False, + offline=True, +) + +trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, + callbacks=[linear_probe, rankme], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, + accelerator="gpu", + devices=8, + strategy=DDPStrategy( + find_unused_parameters=True, + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() \ No newline at end of file diff --git a/stable_ssl/utils/pos_embed.py b/stable_ssl/utils/pos_embed.py new file mode 100644 index 00000000..be78e875 --- /dev/null +++ b/stable_ssl/utils/pos_embed.py @@ -0,0 +1,67 @@ +import numpy as np + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """Get 1D sinusoidal positional embedding from grid. + + Args: + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + + Returns: + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """Get 2D sinusoidal positional embedding. 
+ + Args: + embed_dim: embedding dimension + grid_size: int of the grid height and width + cls_token: whether to include class token + + Returns: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed
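
The training scripts above each wrap get_2d_sincos_pos_embed in a small pos_embed() helper that turns the numpy table into a per-batch torch tensor, optionally with a zero row for the cls token. A minimal sketch of that usage, assuming illustrative sizes only (batch of 4, embed_dim 192, a 16x16 patch grid, i.e. a 32x32 image with patch_size 2):

import torch
from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed

batch_size, embed_dim, grid_size = 4, 192, 16
num_patches = grid_size * grid_size

# numpy table of shape [num_patches, embed_dim]; embed_dim must be divisible by 4
table = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
assert table.shape == (num_patches, embed_dim)

# broadcast to a batch of patch tokens, as the scripts' pos_embed() helpers do
patches = torch.zeros(batch_size, num_patches, embed_dim)
pe = torch.from_numpy(table).float().to(patches.device).repeat(batch_size, 1, 1)
tokens = patches + pe  # [B, N, D]

# MAE-style variant: prepend a zero row so the cls token carries no positional signal
pe_with_cls = torch.cat([torch.zeros(batch_size, 1, embed_dim), pe], dim=1)  # [B, 1+N, D]
print(tokens.shape, pe_with_cls.shape)

Note that the table depends only on embed_dim and grid_size, so it can be computed once and cached rather than rebuilt on every forward pass.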