From 6dc75fefa357a77d6d48c18e3be52d69e259fdc7 Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 9 Aug 2025 07:38:57 +0000 Subject: [PATCH 01/28] IJEPA Multiblock masking with visualizations --- stable_ssl/data/transforms.py | 24 +++ stable_ssl/utils/masking.py | 330 ++++++++++++++++++++++++++++++++++ 2 files changed, 354 insertions(+) create mode 100644 stable_ssl/utils/masking.py diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 77e6ceda..1d741de8 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -15,6 +15,8 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import query_chw +from stable_ssl.utils.masking import multi_block_mask + class Transform(v2.Transform): """Base transform class extending torchvision v2.Transform with nested data handling.""" @@ -711,6 +713,28 @@ def __call__(self, sample): return sample +class MultiBlockMask(Transform): + """Transform that adds block masks to batch.""" + + def __init__(self, patch_size=16, mask_ratio=0.75): + super().__init__() + self.patch_size = patch_size + self.mask_ratio = mask_ratio + + def __call__(self, x): + H, W = x["image"].shape[-2:] + mask = multi_block_mask( + H // self.patch_size, + W // self.patch_size, + self.mask_ratio, + ) + x["mask"] = mask + # Mask ratio that was actually sampled (since it's not exact) + sample_mask_ratio = mask.sum().item() / mask.numel() + x[self.get_name(x)] = torch.tensor([self.mask_ratio, sample_mask_ratio]) + return x + + # class MultiTransforms(v2.Transform): # def __init__(self, transforms, repeats: list = None): # super().__init__() diff --git a/stable_ssl/utils/masking.py b/stable_ssl/utils/masking.py new file mode 100644 index 00000000..e92453e9 --- /dev/null +++ b/stable_ssl/utils/masking.py @@ -0,0 +1,330 @@ +import math +import torch +from typing import Tuple +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap +import matplotlib.patches as patches + +def _sample_block_size( + height: int, + width: int, + min_scale: float, + max_scale: float, + min_aspect_ratio: float, + max_aspect_ratio: float, +): + """Sample a single block mask for an image. + + Args: + height (int): Height of the image in patches. + width (int): Width of the image in patches. + min_scale (float): Minimum scale factor for block area relative to total image area. + max_scale (float): Maximum scale factor for block area relative to total image area. + min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. + max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. + + Returns: + tuple[int, int]: A tuple (h, w) containing the sampled block height and width. + """ + _rand = torch.rand(1).item() + mask_scale = min_scale + _rand * (max_scale - min_scale) + max_keep = int(height * width * mask_scale) + aspect_ratio = min_aspect_ratio + _rand * (max_aspect_ratio - min_aspect_ratio) + + # Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + h = min(h, height) + + w = int(round(math.sqrt(max_keep / aspect_ratio))) + w = min(w, width) + + return (h, w) + + +def _sample_block_mask( + image_size: Tuple[int, int], + block_size: Tuple[int, int], + min_keep: int = 1, +): + """Sample a single block mask for an image. + Because mask positions are chosen randomly, we can occasionally end up with a mask that is too small. + This function will retry until a valid mask is found. 
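+    If no attempt reaches min_keep within max_attempts (20) tries, the most
+    recently sampled block is returned as-is.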
+ + Args: + height (int): Height of the image in patches. + width (int): Width of the image in patches. + min_scale (float): Minimum scale factor for block area relative to total image area. + max_scale (float): Maximum scale factor for block area relative to total image area. + min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. + max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. + min_keep (int): Minimum number of patches that must be in the block. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - mask: Binary tensor indices of masked patches (flattened) + - mask_complement: Binary tensor where 1 = available for future blocks + """ + h, w = block_size + height, width = image_size + max_attempts = 20 + + for _ in range(max_attempts): + top = torch.randint(0, height - h + 1, (1,)).item() + left = torch.randint(0, width - w + 1, (1,)).item() + + mask = torch.zeros((height, width), dtype=torch.int32) + mask[top:top+h, left:left+w] = 1 + # Return the first mask that satisfies min_keep. + if torch.sum(mask) >= min_keep: return mask + + # If we run out of attempts, return whatever we had last. + else: return mask + + +def multi_block_mask( + height: int, + width: int, + num_blocks: int = 4, + context_scale: Tuple[float, float] = (0.85, 1.0), # -- enc mask scale + target_scale: Tuple[float, float] = (0.15, 0.2), # -- pred mask scale + aspect_ratio: Tuple[float, float] = (0.75, 1.5), + min_keep: int = 1, +) -> torch.Tensor: + """Generate block mask(s) for an image. + + Args: + height: Height in patches + width: Width in patches + mask_ratio: Fraction to mask + num_blocks: Number of mask blocks + aspect_ratio: (min, max) aspect ratio for blocks + min_keep: Minimum patches to keep unmasked + generator: For reproducibility + + Returns: + Binary mask of shape (height, width) where 1 = masked, 0 = visible + """ + min_scale, max_scale = context_scale + # No aspect ratio for the context block + h, w = _sample_block_size(height, width, min_scale, max_scale, 1., 1.) + + # -- Sample context mask + mask_enc = _sample_block_mask( + (height, width), + (h, w), + min_keep, + ) + + min_scale, max_scale = target_scale + min_aspect_ratio, max_aspect_ratio = aspect_ratio + + masks_pred = [] + for _ in range(num_blocks): + h, w = _sample_block_size( + height, width, + min_scale, max_scale, + min_aspect_ratio, max_aspect_ratio + ) + masks_pred += [ + _sample_block_mask( + (height, width), + (h, w), + min_keep + )] + + # NOTE Since 1 == discard and 0 == keep, combining masks is an OR operation + combined_mask = (1 - mask_enc).clone() + for mask in masks_pred[1:]: + combined_mask = torch.logical_or(combined_mask, mask) + + # -- Return masks + return mask_enc, masks_pred, combined_mask + +def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): + """ + Visualize the I-JEPA masking strategy with multiple examples. 
+ + Args: + height: Image height in patches + width: Image width in patches + num_examples: Number of masking examples to show + save_path: Path to save the visualization + """ + + fig, axes = plt.subplots(2, num_examples, figsize=(3*num_examples, 6)) + if num_examples == 1: + axes = axes.reshape(2, 1) + + # Set random seed for reproducible examples + torch.manual_seed(42) + + for i in range(num_examples): + # Generate masks + context_mask, target_masks, combined_mask = multi_block_mask( + height, width, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5) + ) + + # Convert to numpy for visualization + context_np = context_mask.numpy() + combined_np = combined_mask.numpy() + + # Create visualization grids + vis_grid_separate = np.zeros((height, width)) + vis_grid_combined = np.zeros((height, width)) + + # Color coding: 0=visible, 1=context, 2-5=targets 1-4 + # For separate visualization + vis_grid_separate[context_np == 1] = 1 # Context in blue + for j, target_mask in enumerate(target_masks): + target_np = target_mask.numpy() + vis_grid_separate[target_np == 1] = j + 2 # Targets in different colors + + # For combined visualization + vis_grid_combined[combined_np == 1] = 1 # All masked regions + + # Plot separate masks (context + targets) + ax1 = axes[0, i] + colors = ['white', 'lightblue', 'red', 'orange', 'green', 'purple'] + cmap = ListedColormap(colors[:6]) + im1 = ax1.imshow(vis_grid_separate, cmap=cmap, vmin=0, vmax=5) + ax1.set_title(f'Example {i+1}: Context + Targets', fontsize=10) + ax1.set_xticks(range(0, width, 2)) + ax1.set_yticks(range(0, height, 2)) + ax1.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax1.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax1.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) + + # Plot combined mask + ax2 = axes[1, i] + cmap_combined = ListedColormap(['white', 'black']) + im2 = ax2.imshow(vis_grid_combined, cmap=cmap_combined, vmin=0, vmax=1) + ax2.set_title(f'Combined Mask', fontsize=10) + ax2.set_xticks(range(0, width, 2)) + ax2.set_yticks(range(0, height, 2)) + ax2.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax2.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax2.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) + + # Add legend + legend_elements = [ + patches.Patch(color='white', label='Visible patches'), + patches.Patch(color='lightblue', label='Context block'), + patches.Patch(color='red', label='Target block 1'), + patches.Patch(color='orange', label='Target block 2'), + patches.Patch(color='green', label='Target block 3'), + patches.Patch(color='purple', label='Target block 4') + ] + + fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) + + plt.suptitle('I-JEPA Masking Strategy Visualization\n(Top: Context + Target blocks, Bottom: Combined mask)', + fontsize=14, y=0.95) + plt.tight_layout() + plt.subplots_adjust(bottom=0.15, top=0.85) + + # Save the figure + plt.savefig(save_path, dpi=300, bbox_inches='tight') + print(f"Visualization saved to: {save_path}") + plt.close() + +def analyze_masking_statistics(height=14, width=14, num_samples=1000): + """ + Analyze statistics of the masking strategy. 
+ """ + context_scales = [] + target_scales = [] + aspect_ratios = [] + + torch.manual_seed(42) + + for _ in range(num_samples): + context_mask, target_masks, combined_mask = multi_block_mask( + height, width, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5) + ) + + total_patches = height * width + context_scale = torch.sum(context_mask).item() / total_patches + context_scales.append(context_scale) + + for target_mask in target_masks: + target_scale = torch.sum(target_mask).item() / total_patches + target_scales.append(target_scale) + + # Calculate aspect ratio of target + coords = torch.where(target_mask == 1) + if len(coords[0]) > 0: + h_extent = coords[0].max() - coords[0].min() + 1 + w_extent = coords[1].max() - coords[1].min() + 1 + aspect_ratios.append(h_extent.item() / w_extent.item()) + + # Create statistics plot + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') + axes[0].set_title('Context Block Scales') + axes[0].set_xlabel('Scale (fraction of image)') + axes[0].set_ylabel('Frequency') + axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', + label=f'Mean: {np.mean(context_scales):.3f}') + axes[0].legend() + + axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') + axes[1].set_title('Target Block Scales') + axes[1].set_xlabel('Scale (fraction of image)') + axes[1].set_ylabel('Frequency') + axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', + label=f'Mean: {np.mean(target_scales):.3f}') + axes[1].legend() + + axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') + axes[2].set_title('Target Block Aspect Ratios') + axes[2].set_xlabel('Aspect Ratio (height/width)') + axes[2].set_ylabel('Frequency') + axes[2].axvline(np.mean(aspect_ratios), color='red', linestyle='--', + label=f'Mean: {np.mean(aspect_ratios):.3f}') + axes[2].legend() + + plt.tight_layout() + plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') + print("Statistics saved to: ijepa_masking_statistics.png") + plt.close() + + return { + 'context_scales': context_scales, + 'target_scales': target_scales, + 'aspect_ratios': aspect_ratios + } + +if __name__ == "__main__": + print("Generating I-JEPA masking visualizations...") + + # Create main visualization + visualize_masking_strategy(height=14, width=14, num_examples=8, + save_path="ijepa_masking_examples.png") + + # Generate statistics + stats = analyze_masking_statistics() + + print(f"\nMasking Statistics:") + print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") + print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") + print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") + + print("\nVisualization complete! 
Check the generated PNG files.") \ No newline at end of file From 329154723f06a8d542e36f98cbded560c1b10ea4 Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 9 Aug 2025 08:04:28 +0000 Subject: [PATCH 02/28] Multiblock masking complete --- stable_ssl/utils/masking.py | 332 ++++++++++++++++++------------------ 1 file changed, 167 insertions(+), 165 deletions(-) diff --git a/stable_ssl/utils/masking.py b/stable_ssl/utils/masking.py index e92453e9..408b75cb 100644 --- a/stable_ssl/utils/masking.py +++ b/stable_ssl/utils/masking.py @@ -62,8 +62,8 @@ def _sample_block_mask( Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - mask: Binary tensor indices of masked patches (flattened) - - mask_complement: Binary tensor where 1 = available for future blocks + - mask: Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) + - pred_mask: Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) """ h, w = block_size height, width = image_size @@ -90,7 +90,7 @@ def multi_block_mask( target_scale: Tuple[float, float] = (0.15, 0.2), # -- pred mask scale aspect_ratio: Tuple[float, float] = (0.75, 1.5), min_keep: int = 1, -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: """Generate block mask(s) for an image. Args: @@ -100,10 +100,11 @@ def multi_block_mask( num_blocks: Number of mask blocks aspect_ratio: (min, max) aspect ratio for blocks min_keep: Minimum patches to keep unmasked - generator: For reproducibility Returns: - Binary mask of shape (height, width) where 1 = masked, 0 = visible + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) + - Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) """ min_scale, max_scale = context_scale # No aspect ratio for the context block @@ -134,197 +135,198 @@ def multi_block_mask( )] # NOTE Since 1 == discard and 0 == keep, combining masks is an OR operation - combined_mask = (1 - mask_enc).clone() + combined_pred_mask = masks_pred[0].clone() for mask in masks_pred[1:]: - combined_mask = torch.logical_or(combined_mask, mask) + combined_pred_mask |= mask + + # Remove all target masks from the context + compliment_mask_enc = mask_enc.clone() + compliment_mask_enc &= ~combined_pred_mask # -- Return masks - return mask_enc, masks_pred, combined_mask + return compliment_mask_enc, combined_pred_mask#, masks_pred -def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): - """ - Visualize the I-JEPA masking strategy with multiple examples. + +# def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): +# """ +# Visualize the I-JEPA masking strategy with multiple examples. 
- Args: - height: Image height in patches - width: Image width in patches - num_examples: Number of masking examples to show - save_path: Path to save the visualization - """ +# Args: +# height: Image height in patches +# width: Image width in patches +# num_examples: Number of masking examples to show +# save_path: Path to save the visualization +# """ - fig, axes = plt.subplots(2, num_examples, figsize=(3*num_examples, 6)) - if num_examples == 1: - axes = axes.reshape(2, 1) +# # Each example shows: context + 4 target blocks = 5 columns total +# fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3*num_examples)) +# if num_examples == 1: +# axes = axes.reshape(1, 5) - # Set random seed for reproducible examples - torch.manual_seed(42) +# # Set random seed for reproducible examples +# torch.manual_seed(42) - for i in range(num_examples): - # Generate masks - context_mask, target_masks, combined_mask = multi_block_mask( - height, width, - num_blocks=4, - context_scale=(0.85, 1.0), - target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5) - ) - - # Convert to numpy for visualization - context_np = context_mask.numpy() - combined_np = combined_mask.numpy() +# for i in range(num_examples): +# # Generate masks - now returns (cleaned_context, combined_targets, individual_targets) +# cleaned_context_mask, individual_target_masks = multi_block_mask( +# height, width, +# num_blocks=4, +# context_scale=(0.85, 1.0), +# target_scale=(0.15, 0.2), +# aspect_ratio=(0.75, 1.5) +# ) - # Create visualization grids - vis_grid_separate = np.zeros((height, width)) - vis_grid_combined = np.zeros((height, width)) +# # Convert to numpy for visualization +# context_np = cleaned_context_mask.numpy() - # Color coding: 0=visible, 1=context, 2-5=targets 1-4 - # For separate visualization - vis_grid_separate[context_np == 1] = 1 # Context in blue - for j, target_mask in enumerate(target_masks): - target_np = target_mask.numpy() - vis_grid_separate[target_np == 1] = j + 2 # Targets in different colors +# # Column 0: Context mask only +# ax = axes[i, 0] +# cmap_context = ListedColormap(['white', 'lightblue']) +# ax.imshow(context_np, cmap=cmap_context, vmin=0, vmax=1) +# ax.set_title('Context' if i == 0 else '', fontsize=10) +# ax.set_xticks(range(0, width, 4)) +# ax.set_yticks(range(0, height, 4)) +# ax.grid(True, alpha=0.3) - # For combined visualization - vis_grid_combined[combined_np == 1] = 1 # All masked regions +# # Add grid lines for patches +# for x in range(width + 1): +# ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) +# for y in range(height + 1): +# ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - # Plot separate masks (context + targets) - ax1 = axes[0, i] - colors = ['white', 'lightblue', 'red', 'orange', 'green', 'purple'] - cmap = ListedColormap(colors[:6]) - im1 = ax1.imshow(vis_grid_separate, cmap=cmap, vmin=0, vmax=5) - ax1.set_title(f'Example {i+1}: Context + Targets', fontsize=10) - ax1.set_xticks(range(0, width, 2)) - ax1.set_yticks(range(0, height, 2)) - ax1.grid(True, alpha=0.3) +# # Add row label +# if i == 0: +# ax.set_ylabel('Example 1', fontsize=12, rotation=0, ha='right', va='center') +# else: +# ax.set_ylabel(f'Example {i+1}', fontsize=12, rotation=0, ha='right', va='center') - # Add grid lines for patches - for x in range(width + 1): - ax1.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) - for y in range(height + 1): - ax1.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - - # Plot combined mask - ax2 = axes[1, i] - cmap_combined = 
ListedColormap(['white', 'black']) - im2 = ax2.imshow(vis_grid_combined, cmap=cmap_combined, vmin=0, vmax=1) - ax2.set_title(f'Combined Mask', fontsize=10) - ax2.set_xticks(range(0, width, 2)) - ax2.set_yticks(range(0, height, 2)) - ax2.grid(True, alpha=0.3) - - # Add grid lines for patches - for x in range(width + 1): - ax2.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) - for y in range(height + 1): - ax2.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) +# # Columns 1-4: Individual target masks +# target_colors = ['red', 'orange', 'green', 'purple'] +# for j, target_mask in enumerate(individual_target_masks): +# ax = axes[i, j+1] +# target_np = target_mask.numpy() + +# # Create colormap for this target +# cmap_target = ListedColormap(['white', target_colors[j]]) +# ax.imshow(target_np, cmap=cmap_target, vmin=0, vmax=1) +# ax.set_title(f'Target {j+1}' if i == 0 else '', fontsize=10) +# ax.set_xticks(range(0, width, 4)) +# ax.set_yticks(range(0, height, 4)) +# ax.grid(True, alpha=0.3) + +# # Add grid lines for patches +# for x in range(width + 1): +# ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) +# for y in range(height + 1): +# ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - # Add legend - legend_elements = [ - patches.Patch(color='white', label='Visible patches'), - patches.Patch(color='lightblue', label='Context block'), - patches.Patch(color='red', label='Target block 1'), - patches.Patch(color='orange', label='Target block 2'), - patches.Patch(color='green', label='Target block 3'), - patches.Patch(color='purple', label='Target block 4') - ] +# # Add legend +# legend_elements = [ +# patches.Patch(color='white', label='Visible patches'), +# patches.Patch(color='lightblue', label='Context block'), +# patches.Patch(color='red', label='Target block 1'), +# patches.Patch(color='orange', label='Target block 2'), +# patches.Patch(color='green', label='Target block 3'), +# patches.Patch(color='purple', label='Target block 4') +# ] - fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) +# fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) - plt.suptitle('I-JEPA Masking Strategy Visualization\n(Top: Context + Target blocks, Bottom: Combined mask)', - fontsize=14, y=0.95) - plt.tight_layout() - plt.subplots_adjust(bottom=0.15, top=0.85) +# plt.suptitle('I-JEPA Masking Strategy: Context and Individual Target Blocks', +# fontsize=14, y=0.95) +# plt.tight_layout() +# plt.subplots_adjust(bottom=0.12, top=0.88) - # Save the figure - plt.savefig(save_path, dpi=300, bbox_inches='tight') - print(f"Visualization saved to: {save_path}") - plt.close() +# # Save the figure +# plt.savefig(save_path, dpi=300, bbox_inches='tight') +# print(f"Visualization saved to: {save_path}") +# plt.close() -def analyze_masking_statistics(height=14, width=14, num_samples=1000): - """ - Analyze statistics of the masking strategy. - """ - context_scales = [] - target_scales = [] - aspect_ratios = [] +# def analyze_masking_statistics(height=14, width=14, num_samples=1000): +# """ +# Analyze statistics of the masking strategy. 
+# """ +# context_scales = [] +# target_scales = [] +# aspect_ratios = [] - torch.manual_seed(42) +# torch.manual_seed(42) - for _ in range(num_samples): - context_mask, target_masks, combined_mask = multi_block_mask( - height, width, - num_blocks=4, - context_scale=(0.85, 1.0), - target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5) - ) +# for _ in range(num_samples): +# cleaned_context_mask, combined_target_mask, individual_target_masks = multi_block_mask( +# height, width, +# num_blocks=4, +# context_scale=(0.85, 1.0), +# target_scale=(0.15, 0.2), +# aspect_ratio=(0.75, 1.5) +# ) - total_patches = height * width - context_scale = torch.sum(context_mask).item() / total_patches - context_scales.append(context_scale) +# total_patches = height * width +# context_scale = torch.sum(cleaned_context_mask).item() / total_patches +# context_scales.append(context_scale) - for target_mask in target_masks: - target_scale = torch.sum(target_mask).item() / total_patches - target_scales.append(target_scale) +# for target_mask in individual_target_masks: +# target_scale = torch.sum(target_mask).item() / total_patches +# target_scales.append(target_scale) - # Calculate aspect ratio of target - coords = torch.where(target_mask == 1) - if len(coords[0]) > 0: - h_extent = coords[0].max() - coords[0].min() + 1 - w_extent = coords[1].max() - coords[1].min() + 1 - aspect_ratios.append(h_extent.item() / w_extent.item()) +# # Calculate aspect ratio of target +# coords = torch.where(target_mask == 1) +# if len(coords[0]) > 0: +# h_extent = coords[0].max() - coords[0].min() + 1 +# w_extent = coords[1].max() - coords[1].min() + 1 +# aspect_ratios.append(h_extent.item() / w_extent.item()) - # Create statistics plot - fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +# # Create statistics plot +# fig, axes = plt.subplots(1, 3, figsize=(15, 4)) - axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') - axes[0].set_title('Context Block Scales') - axes[0].set_xlabel('Scale (fraction of image)') - axes[0].set_ylabel('Frequency') - axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', - label=f'Mean: {np.mean(context_scales):.3f}') - axes[0].legend() +# axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') +# axes[0].set_title('Context Block Scales') +# axes[0].set_xlabel('Scale (fraction of image)') +# axes[0].set_ylabel('Frequency') +# axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', +# label=f'Mean: {np.mean(context_scales):.3f}') +# axes[0].legend() - axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') - axes[1].set_title('Target Block Scales') - axes[1].set_xlabel('Scale (fraction of image)') - axes[1].set_ylabel('Frequency') - axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', - label=f'Mean: {np.mean(target_scales):.3f}') - axes[1].legend() +# axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') +# axes[1].set_title('Target Block Scales') +# axes[1].set_xlabel('Scale (fraction of image)') +# axes[1].set_ylabel('Frequency') +# axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', +# label=f'Mean: {np.mean(target_scales):.3f}') +# axes[1].legend() - axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') - axes[2].set_title('Target Block Aspect Ratios') - axes[2].set_xlabel('Aspect Ratio (height/width)') - axes[2].set_ylabel('Frequency') - axes[2].axvline(np.mean(aspect_ratios), color='red', 
linestyle='--', - label=f'Mean: {np.mean(aspect_ratios):.3f}') - axes[2].legend() +# axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') +# axes[2].set_title('Target Block Aspect Ratios') +# axes[2].set_xlabel('Aspect Ratio (height/width)') +# axes[2].set_ylabel('Frequency') +# axes[2].axvline(np.mean(aspect_ratios), color='red', linestyle='--', +# label=f'Mean: {np.mean(aspect_ratios):.3f}') +# axes[2].legend() - plt.tight_layout() - plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') - print("Statistics saved to: ijepa_masking_statistics.png") - plt.close() +# plt.tight_layout() +# plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') +# print("Statistics saved to: ijepa_masking_statistics.png") +# plt.close() - return { - 'context_scales': context_scales, - 'target_scales': target_scales, - 'aspect_ratios': aspect_ratios - } +# return { +# 'context_scales': context_scales, +# 'target_scales': target_scales, +# 'aspect_ratios': aspect_ratios +# } -if __name__ == "__main__": - print("Generating I-JEPA masking visualizations...") +# if __name__ == "__main__": +# print("Generating I-JEPA masking visualizations...") - # Create main visualization - visualize_masking_strategy(height=14, width=14, num_examples=8, - save_path="ijepa_masking_examples.png") +# # Create main visualization +# visualize_masking_strategy(height=14, width=14, num_examples=8, +# save_path="ijepa_masking_examples.png") - # Generate statistics - stats = analyze_masking_statistics() +# # Generate statistics +# stats = analyze_masking_statistics() - print(f"\nMasking Statistics:") - print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") - print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") - print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") +# print(f"\nMasking Statistics:") +# print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") +# print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") +# print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") - print("\nVisualization complete! Check the generated PNG files.") \ No newline at end of file +# print("\nVisualization complete! Check the generated PNG files.") \ No newline at end of file From 990ac137b6383c2d48f7192cf5a6a298c15d1a31 Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 9 Aug 2025 08:09:09 +0000 Subject: [PATCH 03/28] docstrings --- stable_ssl/utils/masking.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/stable_ssl/utils/masking.py b/stable_ssl/utils/masking.py index 408b75cb..239e8672 100644 --- a/stable_ssl/utils/masking.py +++ b/stable_ssl/utils/masking.py @@ -52,12 +52,8 @@ def _sample_block_mask( This function will retry until a valid mask is found. Args: - height (int): Height of the image in patches. - width (int): Width of the image in patches. - min_scale (float): Minimum scale factor for block area relative to total image area. - max_scale (float): Maximum scale factor for block area relative to total image area. - min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. - max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. 
+ image_size: Tuple[int, int] - Size of the image in patches + block_size: Tuple[int, int] - Size of the block in patches min_keep (int): Minimum number of patches that must be in the block. Returns: From 517889c35e0ab852d1fbad425f18937033df425b Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 9 Aug 2025 08:15:41 +0000 Subject: [PATCH 04/28] transform --- stable_ssl/data/transforms.py | 37 +++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 1d741de8..0a44e92f 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -714,24 +714,45 @@ def __call__(self, sample): class MultiBlockMask(Transform): - """Transform that adds block masks to batch.""" + """Transform that adds multi-block masks to batch. - def __init__(self, patch_size=16, mask_ratio=0.75): + Args: + patch_size: Size of the patch in patches + num_blocks: Number of blocks to sample + context_scale: Scale of the context block + aspect_ratio: Aspect ratio of the blocks + min_keep: Minimum number of patches that must be in the block + + """ + + def __init__(self, + patch_size=16, + num_blocks=1, + context_scale=(0.85, 1.0), + aspect_ratio=(0.75, 1.5), + min_keep=10, + ): super().__init__() self.patch_size = patch_size - self.mask_ratio = mask_ratio + self.num_blocks = num_blocks + self.context_scale = context_scale + self.aspect_ratio = aspect_ratio + self.min_keep = min_keep def __call__(self, x): H, W = x["image"].shape[-2:] - mask = multi_block_mask( + mask_context, mask_target = multi_block_mask( H // self.patch_size, W // self.patch_size, - self.mask_ratio, + num_blocks=self.num_blocks, + context_scale=self.context_scale, + aspect_ratio=self.aspect_ratio, + min_keep=self.min_keep, ) - x["mask"] = mask + x["mask_context"] = mask_context + x["mask_target"] = mask_target # Mask ratio that was actually sampled (since it's not exact) - sample_mask_ratio = mask.sum().item() / mask.numel() - x[self.get_name(x)] = torch.tensor([self.mask_ratio, sample_mask_ratio]) + x[self.get_name(x)] = mask_target.sum() / (mask_context.numel() + mask_target.numel()) return x From 1cd7c3202587f61a0bd4b168fbc45dafa88718af Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 10 Aug 2025 09:10:59 +0000 Subject: [PATCH 05/28] masking fixed, collate fixed, ijepa nearly done --- stable_ssl/data/masking.py | 338 ++++++++++++++++++ stable_ssl/data/transforms.py | 11 +- .../tests/scripts/train_ijepa_cifar10.py | 263 ++++++++++++++ stable_ssl/utils/masking.py | 328 ----------------- stable_ssl/utils/pos_embed.py | 67 ++++ 5 files changed, 675 insertions(+), 332 deletions(-) create mode 100644 stable_ssl/data/masking.py create mode 100644 stable_ssl/tests/scripts/train_ijepa_cifar10.py delete mode 100644 stable_ssl/utils/masking.py create mode 100644 stable_ssl/utils/pos_embed.py diff --git a/stable_ssl/data/masking.py b/stable_ssl/data/masking.py new file mode 100644 index 00000000..528248ee --- /dev/null +++ b/stable_ssl/data/masking.py @@ -0,0 +1,338 @@ +import math +import torch +from typing import Tuple +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap +import matplotlib.patches as patches + +def _sample_block_size( + height: int, + width: int, + min_scale: float, + max_scale: float, + min_aspect_ratio: float, + max_aspect_ratio: float, +): + """Sample a single block mask for an image. + + Args: + height (int): Height of the image in patches. 
+ width (int): Width of the image in patches. + min_scale (float): Minimum scale factor for block area relative to total image area. + max_scale (float): Maximum scale factor for block area relative to total image area. + min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. + max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. + + Returns: + tuple[int, int]: A tuple (h, w) containing the sampled block height and width. + """ + _rand = torch.rand(1).item() + mask_scale = min_scale + _rand * (max_scale - min_scale) + max_keep = int(height * width * mask_scale) + aspect_ratio = min_aspect_ratio + _rand * (max_aspect_ratio - min_aspect_ratio) + + # Compute block height and width (given scale and aspect-ratio) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + h = min(h, height-1) + + w = int(round(math.sqrt(max_keep / aspect_ratio))) + w = min(w, width-1) + + return (h, w) + + +def _sample_block_mask( + image_size: Tuple[int, int], + block_size: Tuple[int, int], + min_keep: int = 1, +): + """Sample a single block mask for an image. + Because mask positions are chosen randomly, we can occasionally end up with a mask that is too small. + This function will retry until a valid mask is found. + + Args: + image_size: Tuple[int, int] - Size of the image in patches + block_size: Tuple[int, int] - Size of the block in patches + min_keep (int): Minimum number of patches that must be in the block. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - mask: Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) + - pred_mask: Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) + """ + h, w = block_size + height, width = image_size + max_attempts = 20 + + for _ in range(max_attempts): + top = torch.randint(0, height - h + 1, (1,)).item() + left = torch.randint(0, width - w + 1, (1,)).item() + + mask = torch.zeros((height, width), dtype=torch.int32) + mask[top:top+h, left:left+w] = 1 + + # Return the first mask that satisfies min_keep. + if torch.sum(mask) >= min_keep: return mask + + # If we run out of attempts, return whatever we had last. + else: return mask + + +def multi_block_mask( + height: int, + width: int, + num_blocks: int = 4, + context_scale: Tuple[float, float] = (0.85, 1.0), # -- enc mask scale + target_scale: Tuple[float, float] = (0.15, 0.2), # -- pred mask scale + aspect_ratio: Tuple[float, float] = (0.75, 1.5), + min_keep: int = 1, + seed: int = 0, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Generate block mask(s) for an image. 
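+    The context mask and each target mask are returned as 1-D tensors of
+    flattened patch indices (via torch.nonzero), not as dense (height, width)
+    binary grids.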
+ + Args: + height: Height in patches + width: Width in patches + mask_ratio: Fraction to mask + num_blocks: Number of mask blocks + aspect_ratio: (min, max) aspect ratio for blocks + min_keep: Minimum patches to keep unmasked + + Returns: + tuple[list[torch.Tensor], list[torch.Tensor]]: A tuple containing: + - list[torch.Tensor] - Binary tensors indices of patches exposed to encoder (1 = masked, 0 = visible) + - list[torch.Tensor] - Binary tensors indices where of combined target block masks to be predicted (1 = masked, 0 = visible) + """ + + min_scale, max_scale = target_scale + min_aspect_ratio, max_aspect_ratio = aspect_ratio + g = torch.Generator() + g.manual_seed(seed) + + h, w = _sample_block_size( + height, width, + min_scale, max_scale, + min_aspect_ratio, max_aspect_ratio + ) + + masks_pred: list[torch.Tensor] = [] + for _ in range(num_blocks): + masks_pred += [ + _sample_block_mask( + (height, width), + (h, w), + min_keep + )] + + + min_scale, max_scale = context_scale + # No aspect ratio for the context block + h, w = _sample_block_size(height, width, min_scale, max_scale, 1., 1.) + + # -- Sample context mask + mask_enc = _sample_block_mask( + (height, width), + (h, w), + min_keep, + ) + + # NOTE Since 1 == discard and 0 == keep, combining masks is an OR operation + # Remove all target masks from the context + compliment_mask_enc = mask_enc.clone() + + for mask in masks_pred: + compliment_mask_enc &= ~mask + + # -- Return mask indices + return ( + torch.nonzero(compliment_mask_enc.flatten()).squeeze(), + [ + torch.nonzero(mask.flatten()).squeeze() + for mask in masks_pred + ] + ) + +def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): + """ + Visualize the I-JEPA masking strategy with multiple examples. 
+ + Args: + height: Image height in patches + width: Image width in patches + num_examples: Number of masking examples to show + save_path: Path to save the visualization + """ + + # Each example shows: context + 4 target blocks = 5 columns total + fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3*num_examples)) + if num_examples == 1: + axes = axes.reshape(1, 5) + + # Set random seed for reproducible examples + torch.manual_seed(42) + + for i in range(num_examples): + # Generate masks - returns (context, combined_targets, individual_targets) + cleaned_context_mask, individual_target_masks = multi_block_mask( + height, width, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5) + ) + + # Convert to numpy for visualization + context_np = cleaned_context_mask.numpy() + + # Column 0: Context mask only + ax = axes[i, 0] + cmap_context = ListedColormap(['white', 'lightblue']) + ax.imshow(context_np, cmap=cmap_context, vmin=0, vmax=1) + ax.set_title('Context' if i == 0 else '', fontsize=10) + ax.set_xticks(range(0, width, 4)) + ax.set_yticks(range(0, height, 4)) + ax.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) + + # Add row label + if i == 0: + ax.set_ylabel('Example 1', fontsize=12, rotation=0, ha='right', va='center') + else: + ax.set_ylabel(f'Example {i+1}', fontsize=12, rotation=0, ha='right', va='center') + + # Columns 1-4: Individual target masks + target_colors = ['red', 'orange', 'green', 'purple'] + for j, target_mask in enumerate(individual_target_masks): + ax = axes[i, j+1] + target_np = target_mask.numpy() + + # Create colormap for this target + cmap_target = ListedColormap(['white', target_colors[j]]) + ax.imshow(target_np, cmap=cmap_target, vmin=0, vmax=1) + ax.set_title(f'Target {j+1}' if i == 0 else '', fontsize=10) + ax.set_xticks(range(0, width, 4)) + ax.set_yticks(range(0, height, 4)) + ax.grid(True, alpha=0.3) + + # Add grid lines for patches + for x in range(width + 1): + ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + for y in range(height + 1): + ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) + + # Add legend + legend_elements = [ + patches.Patch(color='white', label='Visible patches'), + patches.Patch(color='lightblue', label='Context block'), + patches.Patch(color='red', label='Target block 1'), + patches.Patch(color='orange', label='Target block 2'), + patches.Patch(color='green', label='Target block 3'), + patches.Patch(color='purple', label='Target block 4') + ] + + fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) + + plt.suptitle('I-JEPA Masking Strategy: Context and Individual Target Blocks', + fontsize=14, y=0.95) + plt.tight_layout() + plt.subplots_adjust(bottom=0.12, top=0.88) + + # Save the figure + plt.savefig(save_path, dpi=300, bbox_inches='tight') + print(f"Visualization saved to: {save_path}") + plt.close() + +def analyze_masking_statistics(height=14, width=14, num_samples=1000): + """ + Analyze statistics of the masking strategy. 
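+
+    Args:
+        height: Image height in patches
+        width: Image width in patches
+        num_samples: Number of random mask draws used to compute the statistics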
+ """ + context_scales = [] + target_scales = [] + aspect_ratios = [] + + torch.manual_seed(42) + + for _ in range(num_samples): + cleaned_context_mask, individual_target_masks = multi_block_mask( + height, width, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5) + ) + + total_patches = height * width + context_scale = torch.sum(cleaned_context_mask).item() / total_patches + context_scales.append(context_scale) + + for target_mask in individual_target_masks: + target_scale = torch.sum(target_mask).item() / total_patches + target_scales.append(target_scale) + + # Calculate aspect ratio of target + coords = torch.where(target_mask == 1) + if len(coords[0]) > 0: + h_extent = coords[0].max() - coords[0].min() + 1 + w_extent = coords[1].max() - coords[1].min() + 1 + aspect_ratios.append(h_extent.item() / w_extent.item()) + + # Create statistics plot + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') + axes[0].set_title('Context Block Scales') + axes[0].set_xlabel('Scale (fraction of image)') + axes[0].set_ylabel('Frequency') + axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', + label=f'Mean: {np.mean(context_scales):.3f}') + axes[0].legend() + + axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') + axes[1].set_title('Target Block Scales') + axes[1].set_xlabel('Scale (fraction of image)') + axes[1].set_ylabel('Frequency') + axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', + label=f'Mean: {np.mean(target_scales):.3f}') + axes[1].legend() + + axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') + axes[2].set_title('Target Block Aspect Ratios') + axes[2].set_xlabel('Aspect Ratio (height/width)') + axes[2].set_ylabel('Frequency') + axes[2].axvline(np.mean(aspect_ratios), color='red', linestyle='--', + label=f'Mean: {np.mean(aspect_ratios):.3f}') + axes[2].legend() + + plt.tight_layout() + plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') + print("Statistics saved to: ijepa_masking_statistics.png") + plt.close() + + return { + 'context_scales': context_scales, + 'target_scales': target_scales, + 'aspect_ratios': aspect_ratios + } + +if __name__ == "__main__": + print("Generating I-JEPA masking visualizations...") + + # Create main visualization + visualize_masking_strategy(height=14, width=14, num_examples=8, + save_path="ijepa_masking_examples.png") + + # Generate statistics + stats = analyze_masking_statistics() + + print(f"\nMasking Statistics:") + print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") + print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") + print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") + + print("\nVisualization complete! 
Check the generated PNG files.") \ No newline at end of file diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 0a44e92f..097df233 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -15,7 +15,7 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2._utils import query_chw -from stable_ssl.utils.masking import multi_block_mask +from stable_ssl.data.masking import multi_block_mask class Transform(v2.Transform): @@ -729,6 +729,7 @@ def __init__(self, patch_size=16, num_blocks=1, context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5), min_keep=10, ): @@ -736,23 +737,25 @@ def __init__(self, self.patch_size = patch_size self.num_blocks = num_blocks self.context_scale = context_scale + self.target_scale = target_scale self.aspect_ratio = aspect_ratio self.min_keep = min_keep def __call__(self, x): - H, W = x["image"].shape[-2:] + # TODO This assumes PIL.Image because of transforms.RGB() + H, W = x["image"]._size mask_context, mask_target = multi_block_mask( H // self.patch_size, W // self.patch_size, num_blocks=self.num_blocks, context_scale=self.context_scale, + target_scale=self.target_scale, aspect_ratio=self.aspect_ratio, min_keep=self.min_keep, ) x["mask_context"] = mask_context x["mask_target"] = mask_target - # Mask ratio that was actually sampled (since it's not exact) - x[self.get_name(x)] = mask_target.sum() / (mask_context.numel() + mask_target.numel()) + x[self.get_name(x)] = torch.tensor([len(mask_context), len(mask_target)]) return x diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py new file mode 100644 index 00000000..555fa363 --- /dev/null +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -0,0 +1,263 @@ + +import torch +from torch import nn +from typing import Optional +import torchvision +import torch.nn.functional as F + +import stable_ssl as ssl +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.backbone.utils import TeacherStudentModule + +from timm.models.vision_transformer import VisionTransformer + + +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +patch_size = 4 + +# Based on the in1k_vith14_ep300.yaml config in the ijepa repository +mask_transform_kwargs = dict( + patch_size=patch_size, + num_blocks=4, + context_scale=(0.85, 1.0), + target_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5), + min_keep=20, +) + + +train_transform = transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((32, 32), scale=(0.3, 1.0)), # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 + transforms.MultiBlockMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) +# Don't mask during validation +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((32, 32)), + transforms.CenterCrop((32, 32)), + transforms.ToImage(mean=mean, std=std), +) + +# Use torchvision CIFAR-10 wrapped in FromTorchDataset +cifar_train = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=True, download=True +) +cifar_val = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=False, download=True +) + + +class IndexedDataset(Dataset): + """Custom dataset wrapper that adds sample_idx to each sample.""" + + def __init__(self, dataset, transform=None): + super().__init__(transform) + self.dataset = dataset + + def __getitem__(self, idx): + image, label = self.dataset[idx] + sample = {"image": image, "label": label, 
"sample_idx": idx} + return self.process_sample(sample) + + def __len__(self): + return len(self.dataset) + + +def standardize_masks(batch: list[dict]): + + context_indices = [item.pop('mask_context') for item in batch] + target_indices = [item.pop('mask_target') for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(tgt) + for multiblock in target_indices + for tgt in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] + for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch['mask_context'] = collated_masks_context + batch['mask_target'] = collated_masks_target + return batch + + +train_dataset = IndexedDataset(cifar_train, transform=train_transform) +# IJEPA does not use multi-view sampling like SimCLR etc. because it processes +# single views and handles masking at the model level +train = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=64, + shuffle=True, # Regular shuffling, no RepeatedRandomSampler + num_workers=0, + drop_last=True, + collate_fn=standardize_masks +) + +""" +{ + 'image': torch.Size([64, 3, 32, 32]), + 'label': torch.Size([64]), + 'sample_idx': torch.Size([64]), + 'RandomResizedCrop': torch.Size([64, 4]), + 'mask_context': torch.Size([64, 8, 8]), + 'mask_target': torch.Size([64, 8, 8]), + 'MultiBlockMask': torch.Size([64]) +} +""" + + +val_dataset = IndexedDataset(cifar_val, transform=val_transform) +val = torch.utils.data.DataLoader( + dataset=val_dataset, + batch_size=128, + num_workers=0, + shuffle=True, +) + +""" +{ +'image': torch.Size([128, 3, 32, 32]), +'label': torch.Size([128]), +'sample_idx': torch.Size([128]) +} +""" + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return patches + + +def patchify(image: torch.Tensor, patch_size: int = patch_size): + """Convert image tensor into patches. 
+ + Args: + image: Tensor of shape [B, C, H, W] + + Returns: + patches: Tensor of shape [B, N, P*P*C] where: + N = number of patches (H/P * W/P) + P = patch size + """ + B, C, H, W = image.shape + P = patch_size + + # Unfold into patches + patches = image.unfold(2, P, P).unfold(3, P, P) + + # Reshape to [B, num_patches, patch_dim] + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P*P*C) + + return patches + + +def apply_mask(patches: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + patch_dim = patches.shape[-1] + mask_expanded = mask.unsqueeze(-1).expand(-1,-1,patch_dim) + masked_patches = torch.gather(patches, dim=1, index=mask_expanded) + return masked_patches + + +class IJEPA_Encoder(VisionTransformer): + def forward_features(self, x, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.blocks(x) + x = self.norm(x) + return x + + +class IJEPA_Predictor(VisionTransformer): + # TODO + def forward_features(self, x, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.blocks(x) + x = self.norm(x) + return x + + +encoder_kwargs = dict(patch_size=4, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) +context_encoder = IJEPA_Encoder(**encoder_kwargs) +target_encoder = TeacherStudentModule(context_encoder) +predictor_kwargs = dict(patch_size=4, embed_dim=384, depth=6, num_heads=6, qkv_bias=False) +predictor = IJEPA_Predictor(**predictor_kwargs) + + +def forward(self, batch, stage): + out = {} + if self.training: + mask_context, mask_target = batch['mask_context'], batch['mask_target'] + image_patches = patchify(batch['image']) + pos_embedding = pos_embed(image_patches) + target_patches = self.target_encoder.patch_project(image_patches) + # target encoder is applied on full patches, then masked + out['target_patches'] = apply_mask( + self.target_encoder(target_patches) + pos_embedding, + mask_target + ) + # context encoder is applied on masked patches + context_patches = self.context_encoder.patch_project(apply_mask(image_patches, mask_context)) + context_pos = apply_mask(pos_embedding, mask_context) + out['context_patches'] = self.context_encoder(context_patches + context_pos) + + out['predicted_patches'] = self.predictor( + out['context_patches'], + mask_context, + mask_target, + ) + out['loss'] = self.ijepa_loss( + apply_mask(out['predicted_patches'], mask_target), + out['target_patches'] + ) + else: + image_patches = patchify(batch['image']) + patches = self.target_encoder.patch_project(image_patches) + return self.target_encoder(image_patches) + + +module = ssl.Module( + context_encoder=context_encoder, + target_encoder=target_encoder, + predictor=predictor, + forward=forward, + ijepa_loss=F.mse_loss, +) + + + +if __name__ == "__main__": + train_iter = iter(train) + val_iter = iter(val) + for batch in train_iter: + print({k:v.shape for k,v in batch.items()}) + break + for batch in val_iter: + print({k:v.shape for k,v in batch.items()}) + break + + +# def ijepa_forward(self: ssl.Module, batch: dict) -> dict: +# mask_keep, mask_discard = batch['mask'], 1-batch['mask'] +# batch['tgt_patches'] = self.tgt_enc(mask_discard & batch['images']) +# batch['ctx_patches'] = self.ctx_enc(mask_keep & batch['images']) +# batch['pred_patches'] = self.pred (mask_keep, batch['ctx_patches']) +# # -- calculate loss +# batch['loss'] = self.ijepa_loss(batch) +# return batch + + +# def custom_ijepa_loss(self: ssl.Module, batch: dict) -> Tensor: +# return 
F.mse(batch['pred_patches'], batch['tgt_patches']) + self.vicreg_loss(batch['pred_patches']) diff --git a/stable_ssl/utils/masking.py b/stable_ssl/utils/masking.py deleted file mode 100644 index 239e8672..00000000 --- a/stable_ssl/utils/masking.py +++ /dev/null @@ -1,328 +0,0 @@ -import math -import torch -from typing import Tuple -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.colors import ListedColormap -import matplotlib.patches as patches - -def _sample_block_size( - height: int, - width: int, - min_scale: float, - max_scale: float, - min_aspect_ratio: float, - max_aspect_ratio: float, -): - """Sample a single block mask for an image. - - Args: - height (int): Height of the image in patches. - width (int): Width of the image in patches. - min_scale (float): Minimum scale factor for block area relative to total image area. - max_scale (float): Maximum scale factor for block area relative to total image area. - min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. - max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. - - Returns: - tuple[int, int]: A tuple (h, w) containing the sampled block height and width. - """ - _rand = torch.rand(1).item() - mask_scale = min_scale + _rand * (max_scale - min_scale) - max_keep = int(height * width * mask_scale) - aspect_ratio = min_aspect_ratio + _rand * (max_aspect_ratio - min_aspect_ratio) - - # Compute block height and width (given scale and aspect-ratio) - h = int(round(math.sqrt(max_keep * aspect_ratio))) - h = min(h, height) - - w = int(round(math.sqrt(max_keep / aspect_ratio))) - w = min(w, width) - - return (h, w) - - -def _sample_block_mask( - image_size: Tuple[int, int], - block_size: Tuple[int, int], - min_keep: int = 1, -): - """Sample a single block mask for an image. - Because mask positions are chosen randomly, we can occasionally end up with a mask that is too small. - This function will retry until a valid mask is found. - - Args: - image_size: Tuple[int, int] - Size of the image in patches - block_size: Tuple[int, int] - Size of the block in patches - min_keep (int): Minimum number of patches that must be in the block. - - Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - mask: Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) - - pred_mask: Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) - """ - h, w = block_size - height, width = image_size - max_attempts = 20 - - for _ in range(max_attempts): - top = torch.randint(0, height - h + 1, (1,)).item() - left = torch.randint(0, width - w + 1, (1,)).item() - - mask = torch.zeros((height, width), dtype=torch.int32) - mask[top:top+h, left:left+w] = 1 - # Return the first mask that satisfies min_keep. - if torch.sum(mask) >= min_keep: return mask - - # If we run out of attempts, return whatever we had last. - else: return mask - - -def multi_block_mask( - height: int, - width: int, - num_blocks: int = 4, - context_scale: Tuple[float, float] = (0.85, 1.0), # -- enc mask scale - target_scale: Tuple[float, float] = (0.15, 0.2), # -- pred mask scale - aspect_ratio: Tuple[float, float] = (0.75, 1.5), - min_keep: int = 1, -) -> tuple[torch.Tensor, torch.Tensor]: - """Generate block mask(s) for an image. 
- - Args: - height: Height in patches - width: Width in patches - mask_ratio: Fraction to mask - num_blocks: Number of mask blocks - aspect_ratio: (min, max) aspect ratio for blocks - min_keep: Minimum patches to keep unmasked - - Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) - - Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) - """ - min_scale, max_scale = context_scale - # No aspect ratio for the context block - h, w = _sample_block_size(height, width, min_scale, max_scale, 1., 1.) - - # -- Sample context mask - mask_enc = _sample_block_mask( - (height, width), - (h, w), - min_keep, - ) - - min_scale, max_scale = target_scale - min_aspect_ratio, max_aspect_ratio = aspect_ratio - - masks_pred = [] - for _ in range(num_blocks): - h, w = _sample_block_size( - height, width, - min_scale, max_scale, - min_aspect_ratio, max_aspect_ratio - ) - masks_pred += [ - _sample_block_mask( - (height, width), - (h, w), - min_keep - )] - - # NOTE Since 1 == discard and 0 == keep, combining masks is an OR operation - combined_pred_mask = masks_pred[0].clone() - for mask in masks_pred[1:]: - combined_pred_mask |= mask - - # Remove all target masks from the context - compliment_mask_enc = mask_enc.clone() - compliment_mask_enc &= ~combined_pred_mask - - # -- Return masks - return compliment_mask_enc, combined_pred_mask#, masks_pred - - -# def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): -# """ -# Visualize the I-JEPA masking strategy with multiple examples. - -# Args: -# height: Image height in patches -# width: Image width in patches -# num_examples: Number of masking examples to show -# save_path: Path to save the visualization -# """ - -# # Each example shows: context + 4 target blocks = 5 columns total -# fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3*num_examples)) -# if num_examples == 1: -# axes = axes.reshape(1, 5) - -# # Set random seed for reproducible examples -# torch.manual_seed(42) - -# for i in range(num_examples): -# # Generate masks - now returns (cleaned_context, combined_targets, individual_targets) -# cleaned_context_mask, individual_target_masks = multi_block_mask( -# height, width, -# num_blocks=4, -# context_scale=(0.85, 1.0), -# target_scale=(0.15, 0.2), -# aspect_ratio=(0.75, 1.5) -# ) - -# # Convert to numpy for visualization -# context_np = cleaned_context_mask.numpy() - -# # Column 0: Context mask only -# ax = axes[i, 0] -# cmap_context = ListedColormap(['white', 'lightblue']) -# ax.imshow(context_np, cmap=cmap_context, vmin=0, vmax=1) -# ax.set_title('Context' if i == 0 else '', fontsize=10) -# ax.set_xticks(range(0, width, 4)) -# ax.set_yticks(range(0, height, 4)) -# ax.grid(True, alpha=0.3) - -# # Add grid lines for patches -# for x in range(width + 1): -# ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) -# for y in range(height + 1): -# ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - -# # Add row label -# if i == 0: -# ax.set_ylabel('Example 1', fontsize=12, rotation=0, ha='right', va='center') -# else: -# ax.set_ylabel(f'Example {i+1}', fontsize=12, rotation=0, ha='right', va='center') - -# # Columns 1-4: Individual target masks -# target_colors = ['red', 'orange', 'green', 'purple'] -# for j, target_mask in enumerate(individual_target_masks): -# ax = axes[i, j+1] -# target_np = target_mask.numpy() - -# # Create colormap for this 
target -# cmap_target = ListedColormap(['white', target_colors[j]]) -# ax.imshow(target_np, cmap=cmap_target, vmin=0, vmax=1) -# ax.set_title(f'Target {j+1}' if i == 0 else '', fontsize=10) -# ax.set_xticks(range(0, width, 4)) -# ax.set_yticks(range(0, height, 4)) -# ax.grid(True, alpha=0.3) - -# # Add grid lines for patches -# for x in range(width + 1): -# ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) -# for y in range(height + 1): -# ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - -# # Add legend -# legend_elements = [ -# patches.Patch(color='white', label='Visible patches'), -# patches.Patch(color='lightblue', label='Context block'), -# patches.Patch(color='red', label='Target block 1'), -# patches.Patch(color='orange', label='Target block 2'), -# patches.Patch(color='green', label='Target block 3'), -# patches.Patch(color='purple', label='Target block 4') -# ] - -# fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) - -# plt.suptitle('I-JEPA Masking Strategy: Context and Individual Target Blocks', -# fontsize=14, y=0.95) -# plt.tight_layout() -# plt.subplots_adjust(bottom=0.12, top=0.88) - -# # Save the figure -# plt.savefig(save_path, dpi=300, bbox_inches='tight') -# print(f"Visualization saved to: {save_path}") -# plt.close() - -# def analyze_masking_statistics(height=14, width=14, num_samples=1000): -# """ -# Analyze statistics of the masking strategy. -# """ -# context_scales = [] -# target_scales = [] -# aspect_ratios = [] - -# torch.manual_seed(42) - -# for _ in range(num_samples): -# cleaned_context_mask, combined_target_mask, individual_target_masks = multi_block_mask( -# height, width, -# num_blocks=4, -# context_scale=(0.85, 1.0), -# target_scale=(0.15, 0.2), -# aspect_ratio=(0.75, 1.5) -# ) - -# total_patches = height * width -# context_scale = torch.sum(cleaned_context_mask).item() / total_patches -# context_scales.append(context_scale) - -# for target_mask in individual_target_masks: -# target_scale = torch.sum(target_mask).item() / total_patches -# target_scales.append(target_scale) - -# # Calculate aspect ratio of target -# coords = torch.where(target_mask == 1) -# if len(coords[0]) > 0: -# h_extent = coords[0].max() - coords[0].min() + 1 -# w_extent = coords[1].max() - coords[1].min() + 1 -# aspect_ratios.append(h_extent.item() / w_extent.item()) - -# # Create statistics plot -# fig, axes = plt.subplots(1, 3, figsize=(15, 4)) - -# axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') -# axes[0].set_title('Context Block Scales') -# axes[0].set_xlabel('Scale (fraction of image)') -# axes[0].set_ylabel('Frequency') -# axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', -# label=f'Mean: {np.mean(context_scales):.3f}') -# axes[0].legend() - -# axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') -# axes[1].set_title('Target Block Scales') -# axes[1].set_xlabel('Scale (fraction of image)') -# axes[1].set_ylabel('Frequency') -# axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', -# label=f'Mean: {np.mean(target_scales):.3f}') -# axes[1].legend() - -# axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') -# axes[2].set_title('Target Block Aspect Ratios') -# axes[2].set_xlabel('Aspect Ratio (height/width)') -# axes[2].set_ylabel('Frequency') -# axes[2].axvline(np.mean(aspect_ratios), color='red', linestyle='--', -# label=f'Mean: {np.mean(aspect_ratios):.3f}') -# axes[2].legend() - 
-# plt.tight_layout() -# plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') -# print("Statistics saved to: ijepa_masking_statistics.png") -# plt.close() - -# return { -# 'context_scales': context_scales, -# 'target_scales': target_scales, -# 'aspect_ratios': aspect_ratios -# } - -# if __name__ == "__main__": -# print("Generating I-JEPA masking visualizations...") - -# # Create main visualization -# visualize_masking_strategy(height=14, width=14, num_examples=8, -# save_path="ijepa_masking_examples.png") - -# # Generate statistics -# stats = analyze_masking_statistics() - -# print(f"\nMasking Statistics:") -# print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") -# print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") -# print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") - -# print("\nVisualization complete! Check the generated PNG files.") \ No newline at end of file diff --git a/stable_ssl/utils/pos_embed.py b/stable_ssl/utils/pos_embed.py new file mode 100644 index 00000000..be78e875 --- /dev/null +++ b/stable_ssl/utils/pos_embed.py @@ -0,0 +1,67 @@ +import numpy as np + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """Get 1D sinusoidal positional embedding from grid. + + Args: + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + + Returns: + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """Get 2D sinusoidal positional embedding. 
+ + Args: + embed_dim: embedding dimension + grid_size: int of the grid height and width + cls_token: whether to include class token + + Returns: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed From db9860a356ce626ecca0afb5148b9041fc1b9db6 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 10 Aug 2025 09:11:40 +0000 Subject: [PATCH 06/28] small bug --- stable_ssl/tests/scripts/train_ijepa_cifar10.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 555fa363..e765de3a 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -196,7 +196,7 @@ def forward_features(self, x, attn_mask: Optional[torch.Tensor] = None) -> torch predictor = IJEPA_Predictor(**predictor_kwargs) -def forward(self, batch, stage): +def forward(self: ssl.Module, batch, stage): out = {} if self.training: mask_context, mask_target = batch['mask_context'], batch['mask_target'] @@ -225,7 +225,7 @@ def forward(self, batch, stage): else: image_patches = patchify(batch['image']) patches = self.target_encoder.patch_project(image_patches) - return self.target_encoder(image_patches) + return self.target_encoder(patches) module = ssl.Module( From e1fab883eac7483886273801e61a9d2f72c1c6a9 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 10 Aug 2025 09:14:37 +0000 Subject: [PATCH 07/28] import --- .../tests/scripts/train_ijepa_cifar10.py | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index e765de3a..94a40038 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -9,6 +9,7 @@ from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset from stable_ssl.backbone.utils import TeacherStudentModule +from stable_ssl.utils.pos_embed import get_1d_sincos_pos_embed_from_grid from timm.models.vision_transformer import VisionTransformer @@ -106,19 +107,6 @@ def standardize_masks(batch: list[dict]): collate_fn=standardize_masks ) -""" -{ - 'image': torch.Size([64, 3, 32, 32]), - 'label': torch.Size([64]), - 'sample_idx': torch.Size([64]), - 'RandomResizedCrop': torch.Size([64, 4]), - 'mask_context': torch.Size([64, 8, 8]), - 'mask_target': torch.Size([64, 8, 8]), - 'MultiBlockMask': torch.Size([64]) -} -""" - - val_dataset = IndexedDataset(cifar_val, transform=val_transform) val = torch.utils.data.DataLoader( dataset=val_dataset, @@ -127,13 +115,6 @@ def standardize_masks(batch: list[dict]): shuffle=True, ) -""" -{ -'image': torch.Size([128, 3, 32, 32]), -'label': torch.Size([128]), -'sample_idx': torch.Size([128]) -} -""" data = ssl.data.DataModule(train=train, val=val) From 6367714ac68cdc2fd10657052e6fbfe9fac393c4 Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 11 Aug 2025 09:48:02 +0000 Subject: [PATCH 08/28] IJEPA - need to do sinusoidal posemb still --- stable_ssl/data/masking.py 
| 89 ++--- stable_ssl/data/transforms.py | 48 ++- .../tests/scripts/train_ijepa_cifar10.py | 308 +++++++++++------- 3 files changed, 249 insertions(+), 196 deletions(-) diff --git a/stable_ssl/data/masking.py b/stable_ssl/data/masking.py index 528248ee..fbcfd4b9 100644 --- a/stable_ssl/data/masking.py +++ b/stable_ssl/data/masking.py @@ -1,6 +1,5 @@ import math import torch -from typing import Tuple import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import ListedColormap @@ -43,8 +42,8 @@ def _sample_block_size( def _sample_block_mask( - image_size: Tuple[int, int], - block_size: Tuple[int, int], + image_size: tuple[int, int], + block_size: tuple[int, int], min_keep: int = 1, ): """Sample a single block mask for an image. @@ -58,8 +57,8 @@ def _sample_block_mask( Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - mask: Binary tensor indices of patches exposed to encoder (1 = masked, 0 = visible) - - pred_mask: Binary tensor where of combined target block masks to be predicted (1 = masked, 0 = visible) + - mask: Binary tensor indices of patches exposed to encoder (1 = visible, 0 = masked) + - pred_mask: Binary tensor where of combined target block masks to be predicted (1 = visible, 0 = masked) """ h, w = block_size height, width = image_size @@ -82,76 +81,32 @@ def _sample_block_mask( def multi_block_mask( height: int, width: int, - num_blocks: int = 4, - context_scale: Tuple[float, float] = (0.85, 1.0), # -- enc mask scale - target_scale: Tuple[float, float] = (0.15, 0.2), # -- pred mask scale - aspect_ratio: Tuple[float, float] = (0.75, 1.5), + block_scales: list[tuple[float, float]] = [(0.85, 1.0), *((0.15, 0.2),)*4], + aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), *((0.75, 1.5),)*4], min_keep: int = 1, seed: int = 0, -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Generate block mask(s) for an image. - - Args: - height: Height in patches - width: Width in patches - mask_ratio: Fraction to mask - num_blocks: Number of mask blocks - aspect_ratio: (min, max) aspect ratio for blocks - min_keep: Minimum patches to keep unmasked - - Returns: - tuple[list[torch.Tensor], list[torch.Tensor]]: A tuple containing: - - list[torch.Tensor] - Binary tensors indices of patches exposed to encoder (1 = masked, 0 = visible) - - list[torch.Tensor] - Binary tensors indices where of combined target block masks to be predicted (1 = masked, 0 = visible) - """ - - min_scale, max_scale = target_scale - min_aspect_ratio, max_aspect_ratio = aspect_ratio +) -> list[torch.Tensor, ...]: g = torch.Generator() g.manual_seed(seed) - h, w = _sample_block_size( - height, width, - min_scale, max_scale, - min_aspect_ratio, max_aspect_ratio - ) - - masks_pred: list[torch.Tensor] = [] - for _ in range(num_blocks): - masks_pred += [ - _sample_block_mask( - (height, width), - (h, w), - min_keep - )] - - - min_scale, max_scale = context_scale - # No aspect ratio for the context block - h, w = _sample_block_size(height, width, min_scale, max_scale, 1., 1.) + # mapping from unique combinations of size x aspect ratio to the block size. 
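    # NOTE (editorial): set(zip(...)) below deduplicates (scale, aspect_ratio) pairs,
    # so with the defaults (one context config plus four identical target configs)
    # only two block sizes are sampled per call and all four target blocks share the
    # same (h, w); only their top-left positions differ.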
+ block_scale_to_size = { + (scale, ratio): _sample_block_size( + height, width, + scale[0], scale[1], + ratio[0], ratio[1] + ) + for scale, ratio in set(zip(block_scales, aspect_ratios)) + } - # -- Sample context mask - mask_enc = _sample_block_mask( - (height, width), - (h, w), - min_keep, - ) + masks: list[torch.Tensor] = [ + _sample_block_mask((height, width), block_scale_to_size[((sh, sw), (ah,aw))], min_keep) + for (sh,sw), (ah, aw) in zip(block_scales, aspect_ratios) + ] + # -- Return binary masks + return masks - # NOTE Since 1 == discard and 0 == keep, combining masks is an OR operation - # Remove all target masks from the context - compliment_mask_enc = mask_enc.clone() - - for mask in masks_pred: - compliment_mask_enc &= ~mask - # -- Return mask indices - return ( - torch.nonzero(compliment_mask_enc.flatten()).squeeze(), - [ - torch.nonzero(mask.flatten()).squeeze() - for mask in masks_pred - ] - ) def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): """ diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 097df233..82985888 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -713,8 +713,8 @@ def __call__(self, sample): return sample -class MultiBlockMask(Transform): - """Transform that adds multi-block masks to batch. +class ContextTargetsMultiBlockMask(Transform): + """Transform that adds multi-block masks to batch, with multiple target blocks and one disjoint context block. Args: patch_size: Size of the patch in patches @@ -727,35 +727,47 @@ class MultiBlockMask(Transform): def __init__(self, patch_size=16, - num_blocks=1, context_scale=(0.85, 1.0), - target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),)*4, + target_aspect_ratios=((0.75, 1.5),)*4, min_keep=10, ): super().__init__() self.patch_size = patch_size - self.num_blocks = num_blocks self.context_scale = context_scale - self.target_scale = target_scale - self.aspect_ratio = aspect_ratio + self.context_aspect_ratio = context_aspect_ratio + self.target_scales = target_scales + self.target_aspect_ratios = target_aspect_ratios + + if len(target_scales) != len(target_aspect_ratios): + raise ValueError( + f'Each scale must have its associated aspect ratio and vice versa.', + 'Received {len(target_scales)=} {len(target_aspect_ratios)=}' + ) + self.min_keep = min_keep - + def __call__(self, x): - # TODO This assumes PIL.Image because of transforms.RGB() H, W = x["image"]._size - mask_context, mask_target = multi_block_mask( + # TODO Could this ever fully hide the context? If so, should + # the guardrail be in here or in multi_block_mask? 
Definitely shouldn't be after batch is formed + scales = [self.context_scale, *self.target_scales] + aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] + context_mask, *target_masks = multi_block_mask( H // self.patch_size, W // self.patch_size, - num_blocks=self.num_blocks, - context_scale=self.context_scale, - target_scale=self.target_scale, - aspect_ratio=self.aspect_ratio, + block_scales=scales, + aspect_ratios=aspect_ratios, min_keep=self.min_keep, ) - x["mask_context"] = mask_context - x["mask_target"] = mask_target - x[self.get_name(x)] = torch.tensor([len(mask_context), len(mask_target)]) + # makes targets disjoint with context + for mask in target_masks: + context_mask &= ~mask + + x["mask_context"] = torch.nonzero(context_mask).flatten().squeeze() + x["masks_target"] = [torch.nonzero(mask).flatten().squeeze() for mask in target_masks] + x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) return x diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 94a40038..60be23f9 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -1,9 +1,11 @@ - +import lightning as pl import torch from torch import nn from typing import Optional import torchvision +import torchmetrics import torch.nn.functional as F +from lightning.pytorch.loggers import WandbLogger import stable_ssl as ssl from stable_ssl.data import transforms @@ -16,30 +18,34 @@ mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -patch_size = 4 +height, width, patch_size = 32, 32, 4 +crop_height, crop_width = 32, 32 # # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +# We precompute these so the predictor can make sinusoidal posembeds +num_patches = (height // patch_size) * (width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size # Based on the in1k_vith14_ep300.yaml config in the ijepa repository mask_transform_kwargs = dict( patch_size=patch_size, - num_blocks=4, context_scale=(0.85, 1.0), - target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),)*4, + target_aspect_ratios=((0.75, 1.5),)*4, min_keep=20, ) train_transform = transforms.Compose( transforms.RGB(), - transforms.RandomResizedCrop((32, 32), scale=(0.3, 1.0)), # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 - transforms.MultiBlockMask(**mask_transform_kwargs), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), ) # Don't mask during validation val_transform = transforms.Compose( transforms.RGB(), - transforms.Resize((32, 32)), - transforms.CenterCrop((32, 32)), + transforms.Resize((height, width)), + transforms.CenterCrop((height, width)), transforms.ToImage(mean=mean, std=std), ) @@ -71,14 +77,17 @@ def __len__(self): def standardize_masks(batch: list[dict]): context_indices = [item.pop('mask_context') for item in batch] - target_indices = [item.pop('mask_target') for item in batch] + target_indices = [item.pop('masks_target') for item in batch] batch = torch.utils.data.default_collate(batch) - + """ + [c.unique().shape for c in collated_masks_pred] + [torch.Size([194]), torch.Size([168]), torch.Size([190]), torch.Size([158])] + """ min_keep_enc = min(len(ctx) for ctx in context_indices) min_keep_pred = min( - len(tgt) + len(block) for multiblock in target_indices - for tgt in multiblock + for block in 
multiblock ) context_batch = [ctx[:min_keep_enc] for ctx in context_indices] @@ -90,8 +99,8 @@ def standardize_masks(batch: list[dict]): collated_masks_context = torch.utils.data.default_collate(context_batch) collated_masks_target = torch.utils.data.default_collate(target_batch) - batch['mask_context'] = collated_masks_context - batch['mask_target'] = collated_masks_target + batch['mask_context'] = collated_masks_context + batch['masks_target'] = collated_masks_target return batch @@ -119,126 +128,203 @@ def standardize_masks(batch: list[dict]): data = ssl.data.DataModule(train=train, val=val) -def pos_embed(patches: torch.Tensor) -> torch.Tensor: - return patches +# TODO sinusoidal posembed. for now this is just a dummy fn +def pos_embed(patches: torch.Tensor, device: torch.device, TMP_DIM = 768) -> torch.Tensor: + return torch.zeros(patches.shape[0], patches.shape[1], TMP_DIM, device=device) -def patchify(image: torch.Tensor, patch_size: int = patch_size): - """Convert image tensor into patches. - - Args: - image: Tensor of shape [B, C, H, W] - - Returns: - patches: Tensor of shape [B, N, P*P*C] where: - N = number of patches (H/P * W/P) - P = patch size - """ - B, C, H, W = image.shape - P = patch_size - - # Unfold into patches - patches = image.unfold(2, P, P).unfold(3, P, P) - - # Reshape to [B, num_patches, patch_dim] - patches = patches.permute(0, 2, 3, 1, 4, 5) - num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] - patches = patches.reshape(B, num_patch_h * num_patch_w, P*P*C) +def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack([m.to(x.device, dtype=torch.long) for m in masks], dim=1) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] - return patches +class IJEPA_Encoder(VisionTransformer): + # TODO + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get('weight_init', '') + self.fix_init = kwargs.get('fix_init', False) + ijepa_in_dim = kwargs.pop('ijepa_in_dim') + super().__init__(*args, **kwargs) + + self.ijepa_patch_project = nn.Linear(ijepa_in_dim, self.embed_dim) + + if self.weight_init != 'skip': + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def patchify(self, image: torch.Tensor) -> torch.Tensor: + """Convert image tensor into patches. 
+ + Args: + image: Tensor of shape [B, C, H, W] + + Returns: + patches: Tensor of shape [B, N, P*P*C] where: + N = number of patches (H/P * W/P) + P = patch size + """ + B, C, H, W = image.shape + P = patch_size + + # Unfold into patches + patches = image.unfold(2, P, P).unfold(3, P, P) + + # Reshape to [B, num_patches, patch_dim] + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P*P*C) -def apply_mask(patches: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - patch_dim = patches.shape[-1] - mask_expanded = mask.unsqueeze(-1).expand(-1,-1,patch_dim) - masked_patches = torch.gather(patches, dim=1, index=mask_expanded) - return masked_patches + return patches + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + # assume they are already reshaped to patches + return self.ijepa_patch_project(patches) -class IJEPA_Encoder(VisionTransformer): - def forward_features(self, x, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = self.blocks(x) - x = self.norm(x) + def encode_patches(self, patches: torch.Tensor, with_layernorm: bool = True) -> torch.Tensor: + x = self.blocks(patches) + if with_layernorm: + x = F.layer_norm(x, (x.size(-1),)) return x + def encode_image(self, image: torch.Tensor) -> torch.Tensor: + patches = self.patchify(image) + patches = self.project_patches(patches) + patches = patches + pos_embed(patches) + return self.encode_patches(patches) + class IJEPA_Predictor(VisionTransformer): - # TODO - def forward_features(self, x, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get('weight_init', '') + self.fix_init = kwargs.get('fix_init', False) + self.predictor_num_patches = kwargs.pop('predictor_num_patches') + self.ijepa_encoder_dim = kwargs.pop('ijepa_encoder_dim') + # TODO Fix device somehow and embeddim + self.predictor_pos_embed = pos_embed(torch.zeros(1, self.predictor_num_patches, kwargs['embed_dim']), device=torch.device('cuda'), TMP_DIM=kwargs['embed_dim']) + super().__init__(*args, **kwargs) + self.predictor_pos_embed = nn.Parameter(self.predictor_pos_embed, requires_grad=False) + self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) + self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) + self.predictor_norm = nn.LayerNorm(self.embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if self.weight_init != 'skip': + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def predict_targets(self, context_patches: torch.Tensor, masks_target: list[torch.Tensor]) -> torch.Tensor: + # NOTE context_patches already positionally embedded + B, *_ = context_patches.shape + M = len(masks_target) + + ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, Dp] + + # target position embeddings (stacked per mask): [B*M, K_tgt, D] + pos_all = self.predictor_pos_embed.expand(B, -1, -1) + tgt_pos = apply_masks(pos_all, *masks_target) + + # repeat context across M target blocks: [B*M, N_ctx, D]. 
this means that + # the predictor predicts each target block independently, and not their union, + # as the ijepa repo does + ctx = ctx.repeat_interleave(M, dim=0) + + # mask tokens placed at target positions + N_tgt = tgt_pos.size(1) + pred_tokens = self.mask_token.expand(B*M, N_tgt, -1) + tgt_pos + + # each target block now gets predicted: [B*M, N_ctx+N_tgt, D] + x = torch.cat([ctx, pred_tokens], dim=1) x = self.blocks(x) - x = self.norm(x) - return x + x = self.predictor_norm(x) + pred = x[:, -N_tgt:] + pred = self.predictor_outproj(pred) + return pred -encoder_kwargs = dict(patch_size=4, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) -context_encoder = IJEPA_Encoder(**encoder_kwargs) -target_encoder = TeacherStudentModule(context_encoder) -predictor_kwargs = dict(patch_size=4, embed_dim=384, depth=6, num_heads=6, qkv_bias=False) -predictor = IJEPA_Predictor(**predictor_kwargs) +encoder_kwargs = dict(patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, ijepa_in_dim=patch_channel_dim) +predictor_kwargs = dict(patch_size=patch_size, embed_dim=384, depth=6, num_heads=6, qkv_bias=False, ijepa_encoder_dim=768, + predictor_num_patches=num_patches) def forward(self: ssl.Module, batch, stage): out = {} - if self.training: - mask_context, mask_target = batch['mask_context'], batch['mask_target'] - image_patches = patchify(batch['image']) - pos_embedding = pos_embed(image_patches) - target_patches = self.target_encoder.patch_project(image_patches) - # target encoder is applied on full patches, then masked - out['target_patches'] = apply_mask( - self.target_encoder(target_patches) + pos_embedding, - mask_target - ) - # context encoder is applied on masked patches - context_patches = self.context_encoder.patch_project(apply_mask(image_patches, mask_context)) - context_pos = apply_mask(pos_embedding, mask_context) - out['context_patches'] = self.context_encoder(context_patches + context_pos) - - out['predicted_patches'] = self.predictor( - out['context_patches'], - mask_context, - mask_target, - ) - out['loss'] = self.ijepa_loss( - apply_mask(out['predicted_patches'], mask_target), - out['target_patches'] - ) - else: - image_patches = patchify(batch['image']) - patches = self.target_encoder.patch_project(image_patches) - return self.target_encoder(patches) + target_encoder: IJEPA_Encoder = self.target_encoder.teacher # NOTE Would this break anything? 
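    # Shape sketch for this config (editorial comment; follows from 32x32 images,
    # patch_size=4, embed_dim=768, four target blocks):
    #   image_patches                                -> [B, 64, 48]   (8*8 patches, 4*4*3 values)
    #   out['embedding']                             -> [B, 64, 768]
    #   apply_masks(out['embedding'], *masks_target) -> [B*4, K_tgt, 768]
    #   out['predicted_patches']                     -> [B*4, K_tgt, 768]
    # with K_tgt the per-batch minimum target length enforced in standardize_masks.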
+ context_encoder: IJEPA_Encoder = self.context_encoder + predictor: IJEPA_Predictor = self.predictor + ijepa_loss: nn.Module = self.ijepa_loss + + image_patches = target_encoder.patchify(batch['image']) + pos_embedding = pos_embed(image_patches, device=batch['image'].device) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding + out['embedding'] = target_encoder.encode_patches(target_patches, with_layernorm=True) + + if not self.training: + return out + + mask_context, masks_target = batch['mask_context'], batch['masks_target'] + # target encoder is applied on full patches, then masked + out['target_patches'] = apply_masks(out['embedding'], *masks_target) + # context encoder is applied on masked patches + context_patches = apply_masks(image_patches, mask_context) + context_patches = context_encoder.project_patches(context_patches) + context_patches = context_patches + apply_masks(pos_embedding, mask_context) + out['context_patches'] = context_encoder.encode_patches(context_patches, with_layernorm=False) + out['predicted_patches'] = predictor.predict_targets(context_patches, masks_target) + + out['loss'] = ijepa_loss( + out['predicted_patches'], + out['target_patches'] + ) + return out module = ssl.Module( - context_encoder=context_encoder, - target_encoder=target_encoder, - predictor=predictor, + context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), + target_encoder=TeacherStudentModule(ctx), + predictor=IJEPA_Predictor(**predictor_kwargs), forward=forward, - ijepa_loss=F.mse_loss, + ijepa_loss=F.smooth_l1_loss, ) +trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, + precision='16-mixed', + enable_checkpointing=False, +) +knn_probe = ssl.callbacks.OnlineKNN( + name="knn_probe", + input="embedding", + target="label", + queue_length=20000, + metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(10)}, + input_dim=512, + k=10, +) -if __name__ == "__main__": - train_iter = iter(train) - val_iter = iter(val) - for batch in train_iter: - print({k:v.shape for k,v in batch.items()}) - break - for batch in val_iter: - print({k:v.shape for k,v in batch.items()}) - break - - -# def ijepa_forward(self: ssl.Module, batch: dict) -> dict: -# mask_keep, mask_discard = batch['mask'], 1-batch['mask'] -# batch['tgt_patches'] = self.tgt_enc(mask_discard & batch['images']) -# batch['ctx_patches'] = self.ctx_enc(mask_keep & batch['images']) -# batch['pred_patches'] = self.pred (mask_keep, batch['ctx_patches']) -# # -- calculate loss -# batch['loss'] = self.ijepa_loss(batch) -# return batch - +# Initialize W&B logger with explicit settings +wandb_logger = WandbLogger( + project="ijepa-cifar10", + entity="slightly-more-badass", # Your W&B entity + name="ijepa-cifar10-run", + log_model=False, # Set to True if you want to save model artifacts + offline=False, # Ensure online mode +) -# def custom_ijepa_loss(self: ssl.Module, batch: dict) -> Tensor: -# return F.mse(batch['pred_patches'], batch['tgt_patches']) + self.vicreg_loss(batch['pred_patches']) +trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first + callbacks=[knn_probe], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, +) +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() From 585314728199607062a5a133732b3af96926bc7a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 09:50:59 +0000 Subject: [PATCH 09/28] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- stable_ssl/data/masking.py | 261 ++++++++++-------- stable_ssl/data/transforms.py | 25 +- .../tests/scripts/train_ijepa_cifar10.py | 190 +++++++------ 3 files changed, 269 insertions(+), 207 deletions(-) diff --git a/stable_ssl/data/masking.py b/stable_ssl/data/masking.py index fbcfd4b9..bc81e9fe 100644 --- a/stable_ssl/data/masking.py +++ b/stable_ssl/data/masking.py @@ -1,9 +1,11 @@ import math -import torch + +import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np +import torch from matplotlib.colors import ListedColormap -import matplotlib.patches as patches + def _sample_block_size( height: int, @@ -22,7 +24,7 @@ def _sample_block_size( max_scale (float): Maximum scale factor for block area relative to total image area. min_aspect_ratio (float): Minimum aspect ratio (height/width) for the block. max_aspect_ratio (float): Maximum aspect ratio (height/width) for the block. - + Returns: tuple[int, int]: A tuple (h, w) containing the sampled block height and width. """ @@ -30,13 +32,13 @@ def _sample_block_size( mask_scale = min_scale + _rand * (max_scale - min_scale) max_keep = int(height * width * mask_scale) aspect_ratio = min_aspect_ratio + _rand * (max_aspect_ratio - min_aspect_ratio) - + # Compute block height and width (given scale and aspect-ratio) h = int(round(math.sqrt(max_keep * aspect_ratio))) - h = min(h, height-1) + h = min(h, height - 1) w = int(round(math.sqrt(max_keep / aspect_ratio))) - w = min(w, width-1) + w = min(w, width - 1) return (h, w) @@ -54,7 +56,7 @@ def _sample_block_mask( image_size: Tuple[int, int] - Size of the image in patches block_size: Tuple[int, int] - Size of the block in patches min_keep (int): Minimum number of patches that must be in the block. - + Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing: - mask: Binary tensor indices of patches exposed to encoder (1 = visible, 0 = masked) @@ -67,227 +69,262 @@ def _sample_block_mask( for _ in range(max_attempts): top = torch.randint(0, height - h + 1, (1,)).item() left = torch.randint(0, width - w + 1, (1,)).item() - + mask = torch.zeros((height, width), dtype=torch.int32) - mask[top:top+h, left:left+w] = 1 + mask[top : top + h, left : left + w] = 1 # Return the first mask that satisfies min_keep. - if torch.sum(mask) >= min_keep: return mask + if torch.sum(mask) >= min_keep: + return mask # If we run out of attempts, return whatever we had last. - else: return mask + else: + return mask def multi_block_mask( height: int, width: int, - block_scales: list[tuple[float, float]] = [(0.85, 1.0), *((0.15, 0.2),)*4], - aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), *((0.75, 1.5),)*4], + block_scales: list[tuple[float, float]] = [(0.85, 1.0), *((0.15, 0.2),) * 4], + aspect_ratios: list[tuple[float, float]] = [(1.0, 1.0), *((0.75, 1.5),) * 4], min_keep: int = 1, seed: int = 0, ) -> list[torch.Tensor, ...]: g = torch.Generator() g.manual_seed(seed) - + # mapping from unique combinations of size x aspect ratio to the block size. 
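    # NOTE (editorial): the generator `g` above is seeded but never handed to the
    # torch.rand / torch.randint calls inside _sample_block_size and
    # _sample_block_mask, so the `seed` argument does not currently affect which
    # blocks get sampled.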
block_scale_to_size = { (scale, ratio): _sample_block_size( - height, width, - scale[0], scale[1], - ratio[0], ratio[1] + height, width, scale[0], scale[1], ratio[0], ratio[1] ) for scale, ratio in set(zip(block_scales, aspect_ratios)) } masks: list[torch.Tensor] = [ - _sample_block_mask((height, width), block_scale_to_size[((sh, sw), (ah,aw))], min_keep) - for (sh,sw), (ah, aw) in zip(block_scales, aspect_ratios) + _sample_block_mask( + (height, width), block_scale_to_size[((sh, sw), (ah, aw))], min_keep + ) + for (sh, sw), (ah, aw) in zip(block_scales, aspect_ratios) ] # -- Return binary masks return masks +def visualize_masking_strategy( + height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png" +): + """Visualize the I-JEPA masking strategy with multiple examples. -def visualize_masking_strategy(height=14, width=14, num_examples=6, save_path="ijepa_masking_visualization.png"): - """ - Visualize the I-JEPA masking strategy with multiple examples. - Args: height: Image height in patches - width: Image width in patches + width: Image width in patches num_examples: Number of masking examples to show save_path: Path to save the visualization """ - # Each example shows: context + 4 target blocks = 5 columns total - fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3*num_examples)) + fig, axes = plt.subplots(num_examples, 5, figsize=(15, 3 * num_examples)) if num_examples == 1: axes = axes.reshape(1, 5) - + # Set random seed for reproducible examples torch.manual_seed(42) - + for i in range(num_examples): # Generate masks - returns (context, combined_targets, individual_targets) cleaned_context_mask, individual_target_masks = multi_block_mask( - height, width, + height, + width, num_blocks=4, context_scale=(0.85, 1.0), target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5) + aspect_ratio=(0.75, 1.5), ) - + # Convert to numpy for visualization context_np = cleaned_context_mask.numpy() - + # Column 0: Context mask only ax = axes[i, 0] - cmap_context = ListedColormap(['white', 'lightblue']) + cmap_context = ListedColormap(["white", "lightblue"]) ax.imshow(context_np, cmap=cmap_context, vmin=0, vmax=1) - ax.set_title('Context' if i == 0 else '', fontsize=10) + ax.set_title("Context" if i == 0 else "", fontsize=10) ax.set_xticks(range(0, width, 4)) ax.set_yticks(range(0, height, 4)) ax.grid(True, alpha=0.3) - + # Add grid lines for patches for x in range(width + 1): - ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + ax.axvline(x - 0.5, color="gray", linewidth=0.5, alpha=0.5) for y in range(height + 1): - ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - + ax.axhline(y - 0.5, color="gray", linewidth=0.5, alpha=0.5) + # Add row label if i == 0: - ax.set_ylabel('Example 1', fontsize=12, rotation=0, ha='right', va='center') + ax.set_ylabel("Example 1", fontsize=12, rotation=0, ha="right", va="center") else: - ax.set_ylabel(f'Example {i+1}', fontsize=12, rotation=0, ha='right', va='center') - + ax.set_ylabel( + f"Example {i + 1}", fontsize=12, rotation=0, ha="right", va="center" + ) + # Columns 1-4: Individual target masks - target_colors = ['red', 'orange', 'green', 'purple'] + target_colors = ["red", "orange", "green", "purple"] for j, target_mask in enumerate(individual_target_masks): - ax = axes[i, j+1] + ax = axes[i, j + 1] target_np = target_mask.numpy() - + # Create colormap for this target - cmap_target = ListedColormap(['white', target_colors[j]]) + cmap_target = ListedColormap(["white", target_colors[j]]) ax.imshow(target_np, 
cmap=cmap_target, vmin=0, vmax=1) - ax.set_title(f'Target {j+1}' if i == 0 else '', fontsize=10) + ax.set_title(f"Target {j + 1}" if i == 0 else "", fontsize=10) ax.set_xticks(range(0, width, 4)) ax.set_yticks(range(0, height, 4)) ax.grid(True, alpha=0.3) - + # Add grid lines for patches for x in range(width + 1): - ax.axvline(x - 0.5, color='gray', linewidth=0.5, alpha=0.5) + ax.axvline(x - 0.5, color="gray", linewidth=0.5, alpha=0.5) for y in range(height + 1): - ax.axhline(y - 0.5, color='gray', linewidth=0.5, alpha=0.5) - + ax.axhline(y - 0.5, color="gray", linewidth=0.5, alpha=0.5) + # Add legend legend_elements = [ - patches.Patch(color='white', label='Visible patches'), - patches.Patch(color='lightblue', label='Context block'), - patches.Patch(color='red', label='Target block 1'), - patches.Patch(color='orange', label='Target block 2'), - patches.Patch(color='green', label='Target block 3'), - patches.Patch(color='purple', label='Target block 4') + patches.Patch(color="white", label="Visible patches"), + patches.Patch(color="lightblue", label="Context block"), + patches.Patch(color="red", label="Target block 1"), + patches.Patch(color="orange", label="Target block 2"), + patches.Patch(color="green", label="Target block 3"), + patches.Patch(color="purple", label="Target block 4"), ] - - fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02), ncol=6, fontsize=10) - - plt.suptitle('I-JEPA Masking Strategy: Context and Individual Target Blocks', - fontsize=14, y=0.95) + + fig.legend( + handles=legend_elements, + loc="center", + bbox_to_anchor=(0.5, 0.02), + ncol=6, + fontsize=10, + ) + + plt.suptitle( + "I-JEPA Masking Strategy: Context and Individual Target Blocks", + fontsize=14, + y=0.95, + ) plt.tight_layout() plt.subplots_adjust(bottom=0.12, top=0.88) - + # Save the figure - plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"Visualization saved to: {save_path}") plt.close() + def analyze_masking_statistics(height=14, width=14, num_samples=1000): - """ - Analyze statistics of the masking strategy. 
- """ + """Analyze statistics of the masking strategy.""" context_scales = [] target_scales = [] aspect_ratios = [] - + torch.manual_seed(42) - + for _ in range(num_samples): cleaned_context_mask, individual_target_masks = multi_block_mask( - height, width, + height, + width, num_blocks=4, context_scale=(0.85, 1.0), target_scale=(0.15, 0.2), - aspect_ratio=(0.75, 1.5) + aspect_ratio=(0.75, 1.5), ) - + total_patches = height * width context_scale = torch.sum(cleaned_context_mask).item() / total_patches context_scales.append(context_scale) - + for target_mask in individual_target_masks: target_scale = torch.sum(target_mask).item() / total_patches target_scales.append(target_scale) - + # Calculate aspect ratio of target coords = torch.where(target_mask == 1) if len(coords[0]) > 0: h_extent = coords[0].max() - coords[0].min() + 1 w_extent = coords[1].max() - coords[1].min() + 1 aspect_ratios.append(h_extent.item() / w_extent.item()) - + # Create statistics plot fig, axes = plt.subplots(1, 3, figsize=(15, 4)) - - axes[0].hist(context_scales, bins=30, alpha=0.7, color='lightblue', edgecolor='black') - axes[0].set_title('Context Block Scales') - axes[0].set_xlabel('Scale (fraction of image)') - axes[0].set_ylabel('Frequency') - axes[0].axvline(np.mean(context_scales), color='red', linestyle='--', - label=f'Mean: {np.mean(context_scales):.3f}') + + axes[0].hist( + context_scales, bins=30, alpha=0.7, color="lightblue", edgecolor="black" + ) + axes[0].set_title("Context Block Scales") + axes[0].set_xlabel("Scale (fraction of image)") + axes[0].set_ylabel("Frequency") + axes[0].axvline( + np.mean(context_scales), + color="red", + linestyle="--", + label=f"Mean: {np.mean(context_scales):.3f}", + ) axes[0].legend() - - axes[1].hist(target_scales, bins=30, alpha=0.7, color='orange', edgecolor='black') - axes[1].set_title('Target Block Scales') - axes[1].set_xlabel('Scale (fraction of image)') - axes[1].set_ylabel('Frequency') - axes[1].axvline(np.mean(target_scales), color='red', linestyle='--', - label=f'Mean: {np.mean(target_scales):.3f}') + + axes[1].hist(target_scales, bins=30, alpha=0.7, color="orange", edgecolor="black") + axes[1].set_title("Target Block Scales") + axes[1].set_xlabel("Scale (fraction of image)") + axes[1].set_ylabel("Frequency") + axes[1].axvline( + np.mean(target_scales), + color="red", + linestyle="--", + label=f"Mean: {np.mean(target_scales):.3f}", + ) axes[1].legend() - - axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color='green', edgecolor='black') - axes[2].set_title('Target Block Aspect Ratios') - axes[2].set_xlabel('Aspect Ratio (height/width)') - axes[2].set_ylabel('Frequency') - axes[2].axvline(np.mean(aspect_ratios), color='red', linestyle='--', - label=f'Mean: {np.mean(aspect_ratios):.3f}') + + axes[2].hist(aspect_ratios, bins=30, alpha=0.7, color="green", edgecolor="black") + axes[2].set_title("Target Block Aspect Ratios") + axes[2].set_xlabel("Aspect Ratio (height/width)") + axes[2].set_ylabel("Frequency") + axes[2].axvline( + np.mean(aspect_ratios), + color="red", + linestyle="--", + label=f"Mean: {np.mean(aspect_ratios):.3f}", + ) axes[2].legend() - + plt.tight_layout() - plt.savefig('ijepa_masking_statistics.png', dpi=300, bbox_inches='tight') + plt.savefig("ijepa_masking_statistics.png", dpi=300, bbox_inches="tight") print("Statistics saved to: ijepa_masking_statistics.png") plt.close() - + return { - 'context_scales': context_scales, - 'target_scales': target_scales, - 'aspect_ratios': aspect_ratios + "context_scales": context_scales, + "target_scales": 
target_scales, + "aspect_ratios": aspect_ratios, } + if __name__ == "__main__": print("Generating I-JEPA masking visualizations...") - + # Create main visualization - visualize_masking_strategy(height=14, width=14, num_examples=8, - save_path="ijepa_masking_examples.png") - + visualize_masking_strategy( + height=14, width=14, num_examples=8, save_path="ijepa_masking_examples.png" + ) + # Generate statistics stats = analyze_masking_statistics() - - print(f"\nMasking Statistics:") - print(f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}") - print(f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}") - print(f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}") - - print("\nVisualization complete! Check the generated PNG files.") \ No newline at end of file + + print("\nMasking Statistics:") + print( + f"Context scale - Mean: {np.mean(stats['context_scales']):.3f}, Std: {np.std(stats['context_scales']):.3f}" + ) + print( + f"Target scale - Mean: {np.mean(stats['target_scales']):.3f}, Std: {np.std(stats['target_scales']):.3f}" + ) + print( + f"Aspect ratio - Mean: {np.mean(stats['aspect_ratios']):.3f}, Std: {np.std(stats['aspect_ratios']):.3f}" + ) + + print("\nVisualization complete! Check the generated PNG files.") diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 82985888..f33aeb10 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -715,22 +715,23 @@ def __call__(self, sample): class ContextTargetsMultiBlockMask(Transform): """Transform that adds multi-block masks to batch, with multiple target blocks and one disjoint context block. - + Args: patch_size: Size of the patch in patches num_blocks: Number of blocks to sample context_scale: Scale of the context block aspect_ratio: Aspect ratio of the blocks min_keep: Minimum number of patches that must be in the block - + """ - - def __init__(self, + + def __init__( + self, patch_size=16, context_scale=(0.85, 1.0), context_aspect_ratio=(1.0, 1.0), - target_scales=((0.15, 0.2),)*4, - target_aspect_ratios=((0.75, 1.5),)*4, + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, min_keep=10, ): super().__init__() @@ -742,8 +743,8 @@ def __init__(self, if len(target_scales) != len(target_aspect_ratios): raise ValueError( - f'Each scale must have its associated aspect ratio and vice versa.', - 'Received {len(target_scales)=} {len(target_aspect_ratios)=}' + "Each scale must have its associated aspect ratio and vice versa.", + "Received {len(target_scales)=} {len(target_aspect_ratios)=}", ) self.min_keep = min_keep @@ -752,8 +753,8 @@ def __call__(self, x): H, W = x["image"]._size # TODO Could this ever fully hide the context? If so, should # the guardrail be in here or in multi_block_mask? 
Definitely shouldn't be after batch is formed - scales = [self.context_scale, *self.target_scales] - aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] + scales = [self.context_scale, *self.target_scales] + aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] context_mask, *target_masks = multi_block_mask( H // self.patch_size, W // self.patch_size, @@ -766,7 +767,9 @@ def __call__(self, x): context_mask &= ~mask x["mask_context"] = torch.nonzero(context_mask).flatten().squeeze() - x["masks_target"] = [torch.nonzero(mask).flatten().squeeze() for mask in target_masks] + x["masks_target"] = [ + torch.nonzero(mask).flatten().squeeze() for mask in target_masks + ] x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) return x diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 60be23f9..2fe30911 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -1,20 +1,16 @@ import lightning as pl import torch -from torch import nn -from typing import Optional -import torchvision -import torchmetrics import torch.nn.functional as F +import torchmetrics +import torchvision from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn import stable_ssl as ssl +from stable_ssl.backbone.utils import TeacherStudentModule from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset -from stable_ssl.backbone.utils import TeacherStudentModule -from stable_ssl.utils.pos_embed import get_1d_sincos_pos_embed_from_grid - -from timm.models.vision_transformer import VisionTransformer - mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] @@ -29,8 +25,8 @@ patch_size=patch_size, context_scale=(0.85, 1.0), context_aspect_ratio=(1.0, 1.0), - target_scales=((0.15, 0.2),)*4, - target_aspect_ratios=((0.75, 1.5),)*4, + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, min_keep=20, ) @@ -41,7 +37,7 @@ transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), ) -# Don't mask during validation +# Don't mask during validation val_transform = transforms.Compose( transforms.RGB(), transforms.Resize((height, width)), @@ -75,32 +71,28 @@ def __len__(self): def standardize_masks(batch: list[dict]): - - context_indices = [item.pop('mask_context') for item in batch] - target_indices = [item.pop('masks_target') for item in batch] - batch = torch.utils.data.default_collate(batch) + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) """ [c.unique().shape for c in collated_masks_pred] [torch.Size([194]), torch.Size([168]), torch.Size([190]), torch.Size([158])] """ - min_keep_enc = min(len(ctx) for ctx in context_indices) - min_keep_pred = min( - len(block) - for multiblock in target_indices - for block in multiblock + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock ) - - context_batch = [ctx[:min_keep_enc] for ctx in context_indices] - target_batch = [ - [tgt[:min_keep_pred] for tgt in multiblock] - for multiblock in target_indices + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock 
in target_indices ] collated_masks_context = torch.utils.data.default_collate(context_batch) collated_masks_target = torch.utils.data.default_collate(target_batch) - batch['mask_context'] = collated_masks_context - batch['masks_target'] = collated_masks_target + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target return batch @@ -113,7 +105,7 @@ def standardize_masks(batch: list[dict]): shuffle=True, # Regular shuffling, no RepeatedRandomSampler num_workers=0, drop_last=True, - collate_fn=standardize_masks + collate_fn=standardize_masks, ) val_dataset = IndexedDataset(cifar_val, transform=val_transform) @@ -129,40 +121,42 @@ def standardize_masks(batch: list[dict]): # TODO sinusoidal posembed. for now this is just a dummy fn -def pos_embed(patches: torch.Tensor, device: torch.device, TMP_DIM = 768) -> torch.Tensor: +def pos_embed(patches: torch.Tensor, device: torch.device, TMP_DIM=768) -> torch.Tensor: return torch.zeros(patches.shape[0], patches.shape[1], TMP_DIM, device=device) def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: B, N, D = x.shape M = len(masks) - idx = torch.stack([m.to(x.device, dtype=torch.long) for m in masks], dim=1) # [B, M, K] - x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] - out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] - return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] class IJEPA_Encoder(VisionTransformer): # TODO def __init__(self, *args, **kwargs): - self.weight_init = kwargs.get('weight_init', '') - self.fix_init = kwargs.get('fix_init', False) - ijepa_in_dim = kwargs.pop('ijepa_in_dim') + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + ijepa_in_dim = kwargs.pop("ijepa_in_dim") super().__init__(*args, **kwargs) self.ijepa_patch_project = nn.Linear(ijepa_in_dim, self.embed_dim) - if self.weight_init != 'skip': + if self.weight_init != "skip": self.init_weights(self.weight_init) if self.fix_init: self.fix_init_weight() def patchify(self, image: torch.Tensor) -> torch.Tensor: """Convert image tensor into patches. 
- + Args: image: Tensor of shape [B, C, H, W] - + Returns: patches: Tensor of shape [B, N, P*P*C] where: N = number of patches (H/P * W/P) @@ -170,22 +164,24 @@ def patchify(self, image: torch.Tensor) -> torch.Tensor: """ B, C, H, W = image.shape P = patch_size - + # Unfold into patches patches = image.unfold(2, P, P).unfold(3, P, P) - + # Reshape to [B, num_patches, patch_dim] patches = patches.permute(0, 2, 3, 1, 4, 5) num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] - patches = patches.reshape(B, num_patch_h * num_patch_w, P*P*C) + patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) return patches def project_patches(self, patches: torch.Tensor) -> torch.Tensor: - # assume they are already reshaped to patches + # assume they are already reshaped to patches return self.ijepa_patch_project(patches) - def encode_patches(self, patches: torch.Tensor, with_layernorm: bool = True) -> torch.Tensor: + def encode_patches( + self, patches: torch.Tensor, with_layernorm: bool = True + ) -> torch.Tensor: x = self.blocks(patches) if with_layernorm: x = F.layer_norm(x, (x.size(-1),)) @@ -200,43 +196,51 @@ def encode_image(self, image: torch.Tensor) -> torch.Tensor: class IJEPA_Predictor(VisionTransformer): def __init__(self, *args, **kwargs): - self.weight_init = kwargs.get('weight_init', '') - self.fix_init = kwargs.get('fix_init', False) - self.predictor_num_patches = kwargs.pop('predictor_num_patches') - self.ijepa_encoder_dim = kwargs.pop('ijepa_encoder_dim') + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + self.predictor_num_patches = kwargs.pop("predictor_num_patches") + self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") # TODO Fix device somehow and embeddim - self.predictor_pos_embed = pos_embed(torch.zeros(1, self.predictor_num_patches, kwargs['embed_dim']), device=torch.device('cuda'), TMP_DIM=kwargs['embed_dim']) + self.predictor_pos_embed = pos_embed( + torch.zeros(1, self.predictor_num_patches, kwargs["embed_dim"]), + device=torch.device("cuda"), + TMP_DIM=kwargs["embed_dim"], + ) super().__init__(*args, **kwargs) - self.predictor_pos_embed = nn.Parameter(self.predictor_pos_embed, requires_grad=False) - self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) - self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) - self.predictor_norm = nn.LayerNorm(self.embed_dim) - self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - - if self.weight_init != 'skip': + self.predictor_pos_embed = nn.Parameter( + self.predictor_pos_embed, requires_grad=False + ) + self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) + self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) + self.predictor_norm = nn.LayerNorm(self.embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if self.weight_init != "skip": self.init_weights(self.weight_init) if self.fix_init: self.fix_init_weight() - def predict_targets(self, context_patches: torch.Tensor, masks_target: list[torch.Tensor]) -> torch.Tensor: + def predict_targets( + self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] + ) -> torch.Tensor: # NOTE context_patches already positionally embedded B, *_ = context_patches.shape M = len(masks_target) - ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, Dp] + ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, Dp] # target position embeddings (stacked per mask): 
[B*M, K_tgt, D] - pos_all = self.predictor_pos_embed.expand(B, -1, -1) - tgt_pos = apply_masks(pos_all, *masks_target) + pos_all = self.predictor_pos_embed.expand(B, -1, -1) + tgt_pos = apply_masks(pos_all, *masks_target) - # repeat context across M target blocks: [B*M, N_ctx, D]. this means that + # repeat context across M target blocks: [B*M, N_ctx, D]. this means that # the predictor predicts each target block independently, and not their union, # as the ijepa repo does ctx = ctx.repeat_interleave(M, dim=0) # mask tokens placed at target positions N_tgt = tgt_pos.size(1) - pred_tokens = self.mask_token.expand(B*M, N_tgt, -1) + tgt_pos + pred_tokens = self.mask_token.expand(B * M, N_tgt, -1) + tgt_pos # each target block now gets predicted: [B*M, N_ctx+N_tgt, D] x = torch.cat([ctx, pred_tokens], dim=1) @@ -248,39 +252,57 @@ def predict_targets(self, context_patches: torch.Tensor, masks_target: list[torc return pred -encoder_kwargs = dict(patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, ijepa_in_dim=patch_channel_dim) -predictor_kwargs = dict(patch_size=patch_size, embed_dim=384, depth=6, num_heads=6, qkv_bias=False, ijepa_encoder_dim=768, - predictor_num_patches=num_patches) +encoder_kwargs = dict( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=False, + ijepa_in_dim=patch_channel_dim, +) +predictor_kwargs = dict( + patch_size=patch_size, + embed_dim=384, + depth=6, + num_heads=6, + qkv_bias=False, + ijepa_encoder_dim=768, + predictor_num_patches=num_patches, +) + def forward(self: ssl.Module, batch, stage): out = {} - target_encoder: IJEPA_Encoder = self.target_encoder.teacher # NOTE Would this break anything? + target_encoder: IJEPA_Encoder = ( + self.target_encoder.teacher + ) # NOTE Would this break anything? 
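    # Editorial sketch (untested, not part of the patch): the pos_embed call below is
    # still the zero-valued stub; one way to swap in the 2D sin-cos embedding added in
    # stable_ssl/utils/pos_embed.py could look like this (import and reshape are
    # assumptions, not the author's implementation):
    #
    #   from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed
    #
    #   def pos_embed(patches, device, TMP_DIM=768):
    #       grid_size = int(patches.shape[1] ** 0.5)          # 8 for 32x32 images, patch_size=4
    #       pe = get_2d_sincos_pos_embed(TMP_DIM, grid_size)  # numpy array, (grid_size**2, TMP_DIM)
    #       pe = torch.from_numpy(pe).float().to(device)
    #       return pe.unsqueeze(0).expand(patches.shape[0], -1, -1)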
context_encoder: IJEPA_Encoder = self.context_encoder predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss - image_patches = target_encoder.patchify(batch['image']) - pos_embedding = pos_embed(image_patches, device=batch['image'].device) - target_patches = target_encoder.project_patches(image_patches) + pos_embedding - out['embedding'] = target_encoder.encode_patches(target_patches, with_layernorm=True) - + image_patches = target_encoder.patchify(batch["image"]) + pos_embedding = pos_embed(image_patches, device=batch["image"].device) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding + out["embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) + if not self.training: return out - mask_context, masks_target = batch['mask_context'], batch['masks_target'] + mask_context, masks_target = batch["mask_context"], batch["masks_target"] # target encoder is applied on full patches, then masked - out['target_patches'] = apply_masks(out['embedding'], *masks_target) + out["target_patches"] = apply_masks(out["embedding"], *masks_target) # context encoder is applied on masked patches - context_patches = apply_masks(image_patches, mask_context) - context_patches = context_encoder.project_patches(context_patches) - context_patches = context_patches + apply_masks(pos_embedding, mask_context) - out['context_patches'] = context_encoder.encode_patches(context_patches, with_layernorm=False) - out['predicted_patches'] = predictor.predict_targets(context_patches, masks_target) - - out['loss'] = ijepa_loss( - out['predicted_patches'], - out['target_patches'] + context_patches = apply_masks(image_patches, mask_context) + context_patches = context_encoder.project_patches(context_patches) + context_patches = context_patches + apply_masks(pos_embedding, mask_context) + out["context_patches"] = context_encoder.encode_patches( + context_patches, with_layernorm=False ) + out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) + + out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) return out @@ -295,7 +317,7 @@ def forward(self: ssl.Module, batch, stage): trainer = pl.Trainer( max_epochs=6, num_sanity_val_steps=0, - precision='16-mixed', + precision="16-mixed", enable_checkpointing=False, ) From c4052ab2cc367b9cfe68a3947671a8f35c9a82f5 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 07:13:39 +0000 Subject: [PATCH 10/28] ijepa example, masked args context/target --- stable_ssl/data/transforms.py | 23 ++++-- .../tests/scripts/train_ijepa_cifar10.py | 79 +++++++++++-------- 2 files changed, 64 insertions(+), 38 deletions(-) diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 82985888..ecf68d48 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -732,6 +732,9 @@ def __init__(self, target_scales=((0.15, 0.2),)*4, target_aspect_ratios=((0.75, 1.5),)*4, min_keep=10, + source: str = "image", + target_context: str = "mask_context", + target_targets: str = "masks_target", ): super().__init__() self.patch_size = patch_size @@ -739,7 +742,9 @@ def __init__(self, self.context_aspect_ratio = context_aspect_ratio self.target_scales = target_scales self.target_aspect_ratios = target_aspect_ratios - + self.source = source + self.target_context = target_context + self.target_targets = target_targets if len(target_scales) != len(target_aspect_ratios): raise ValueError( f'Each scale must have its associated aspect ratio and vice 
versa.', @@ -749,9 +754,15 @@ def __init__(self, self.min_keep = min_keep def __call__(self, x): - H, W = x["image"]._size - # TODO Could this ever fully hide the context? If so, should - # the guardrail be in here or in multi_block_mask? Definitely shouldn't be after batch is formed + source = self.nested_get(x, self.source) + if isinstance(source, PIL.Image.Image): + H, W = source.size + elif isinstance(source, torch.Tensor): + # NOTE assumes _HW + H, W = source.shape[-2:] + else: + raise ValueError(f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead.") + scales = [self.context_scale, *self.target_scales] aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] context_mask, *target_masks = multi_block_mask( @@ -765,8 +776,8 @@ def __call__(self, x): for mask in target_masks: context_mask &= ~mask - x["mask_context"] = torch.nonzero(context_mask).flatten().squeeze() - x["masks_target"] = [torch.nonzero(mask).flatten().squeeze() for mask in target_masks] + x[self.target_context] = torch.nonzero(context_mask).flatten().squeeze() + x[self.target_targets] = [torch.nonzero(mask).flatten().squeeze() for mask in target_masks] x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) return x diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 60be23f9..557a3509 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -1,5 +1,6 @@ import lightning as pl import torch +import math from torch import nn from typing import Optional import torchvision @@ -11,7 +12,7 @@ from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset from stable_ssl.backbone.utils import TeacherStudentModule -from stable_ssl.utils.pos_embed import get_1d_sincos_pos_embed_from_grid +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed from timm.models.vision_transformer import VisionTransformer @@ -79,10 +80,7 @@ def standardize_masks(batch: list[dict]): context_indices = [item.pop('mask_context') for item in batch] target_indices = [item.pop('masks_target') for item in batch] batch = torch.utils.data.default_collate(batch) - """ - [c.unique().shape for c in collated_masks_pred] - [torch.Size([194]), torch.Size([168]), torch.Size([190]), torch.Size([158])] - """ + min_keep_enc = min(len(ctx) for ctx in context_indices) min_keep_pred = min( len(block) @@ -109,9 +107,9 @@ def standardize_masks(batch: list[dict]): # single views and handles masking at the model level train = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=64, + batch_size=512, shuffle=True, # Regular shuffling, no RepeatedRandomSampler - num_workers=0, + num_workers=32, drop_last=True, collate_fn=standardize_masks ) @@ -120,7 +118,7 @@ def standardize_masks(batch: list[dict]): val = torch.utils.data.DataLoader( dataset=val_dataset, batch_size=128, - num_workers=0, + num_workers=32, shuffle=True, ) @@ -128,22 +126,27 @@ def standardize_masks(batch: list[dict]): data = ssl.data.DataModule(train=train, val=val) -# TODO sinusoidal posembed. 
for now this is just a dummy fn -def pos_embed(patches: torch.Tensor, device: torch.device, TMP_DIM = 768) -> torch.Tensor: - return torch.zeros(patches.shape[0], patches.shape[1], TMP_DIM, device=device) +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return torch.from_numpy( + get_2d_sincos_pos_embed( + patches.shape[-1], + int(math.sqrt(patches.shape[1])) + ) + ).to(patches.device).float().repeat(patches.shape[0], 1, 1) def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: + # TODO Does this assume all masks have the same number of indices? + # we can maybe generalize this by returning a list and stacking them for ijepa in the forward B, N, D = x.shape M = len(masks) idx = torch.stack([m.to(x.device, dtype=torch.long) for m in masks], dim=1) # [B, M, K] - x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] - out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] - return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] class IJEPA_Encoder(VisionTransformer): - # TODO def __init__(self, *args, **kwargs): self.weight_init = kwargs.get('weight_init', '') self.fix_init = kwargs.get('fix_init', False) @@ -204,8 +207,7 @@ def __init__(self, *args, **kwargs): self.fix_init = kwargs.get('fix_init', False) self.predictor_num_patches = kwargs.pop('predictor_num_patches') self.ijepa_encoder_dim = kwargs.pop('ijepa_encoder_dim') - # TODO Fix device somehow and embeddim - self.predictor_pos_embed = pos_embed(torch.zeros(1, self.predictor_num_patches, kwargs['embed_dim']), device=torch.device('cuda'), TMP_DIM=kwargs['embed_dim']) + self.predictor_pos_embed = pos_embed(torch.zeros(1, self.predictor_num_patches, kwargs['embed_dim'])) super().__init__(*args, **kwargs) self.predictor_pos_embed = nn.Parameter(self.predictor_pos_embed, requires_grad=False) self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) @@ -223,7 +225,7 @@ def predict_targets(self, context_patches: torch.Tensor, masks_target: list[torc B, *_ = context_patches.shape M = len(masks_target) - ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, Dp] + ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] # target position embeddings (stacked per mask): [B*M, K_tgt, D] pos_all = self.predictor_pos_embed.expand(B, -1, -1) @@ -260,9 +262,12 @@ def forward(self: ssl.Module, batch, stage): ijepa_loss: nn.Module = self.ijepa_loss image_patches = target_encoder.patchify(batch['image']) - pos_embedding = pos_embed(image_patches, device=batch['image'].device) - target_patches = target_encoder.project_patches(image_patches) + pos_embedding + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding out['embedding'] = target_encoder.encode_patches(target_patches, with_layernorm=True) + out['sum_embedding'] = out['embedding'].sum(dim=1) + out['flat_embedding'] = out['embedding'].reshape(out['embedding'].shape[0], -1) if not self.training: return out @@ -292,6 +297,26 @@ def forward(self: ssl.Module, batch, stage): ijepa_loss=F.smooth_l1_loss, ) + +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "sum_embedding", + "label", + probe=torch.nn.Linear(768, 10), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": 
torchmetrics.classification.MulticlassAccuracy(10), + "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + }, +) + +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_embedding", + queue_length=512, # NOTE must be >= batch_size + target_shape=(num_patches, 768), +) + trainer = pl.Trainer( max_epochs=6, num_sanity_val_steps=0, @@ -299,29 +324,19 @@ def forward(self: ssl.Module, batch, stage): enable_checkpointing=False, ) -knn_probe = ssl.callbacks.OnlineKNN( - name="knn_probe", - input="embedding", - target="label", - queue_length=20000, - metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(10)}, - input_dim=512, - k=10, -) - # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( project="ijepa-cifar10", entity="slightly-more-badass", # Your W&B entity name="ijepa-cifar10-run", log_model=False, # Set to True if you want to save model artifacts - offline=False, # Ensure online mode + offline=True, # Ensure offline mode ) trainer = pl.Trainer( max_epochs=6, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first - callbacks=[knn_probe], + callbacks=[linear_probe, rankme], precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, From ad64c0f12161fd75e6ae66f32224e62da229984d Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 07:21:50 +0000 Subject: [PATCH 11/28] precommits --- .../tests/scripts/train_ijepa_cifar10.py | 95 +++++++++++-------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index c3abf839..2a5e565f 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -1,10 +1,7 @@ +import math + import lightning as pl import torch -import math -from torch import nn -from typing import Optional -import torchvision -import torchmetrics import torch.nn.functional as F import torchmetrics import torchvision @@ -16,12 +13,8 @@ from stable_ssl.backbone.utils import TeacherStudentModule from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset -from stable_ssl.backbone.utils import TeacherStudentModule from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed -from timm.models.vision_transformer import VisionTransformer - - mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] height, width, patch_size = 32, 32, 4 @@ -81,16 +74,13 @@ def __len__(self): def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) - context_indices = [item.pop('mask_context') for item in batch] - target_indices = [item.pop('masks_target') for item in batch] - batch = torch.utils.data.default_collate(batch) - - min_keep_enc = min(len(ctx) for ctx in context_indices) - min_keep_pred = min( - len(block) - for multiblock in target_indices - for block in multiblock + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock ) context_batch = [ctx[:min_keep_enc] for ctx in context_indices] @@ -131,12 +121,14 @@ def standardize_masks(batch: list[dict]): def pos_embed(patches: torch.Tensor) -> torch.Tensor: - return torch.from_numpy( - get_2d_sincos_pos_embed( - patches.shape[-1], - int(math.sqrt(patches.shape[1])) + return ( + torch.from_numpy( + 
get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) ) - ).to(patches.device).float().repeat(patches.shape[0], 1, 1) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: @@ -144,13 +136,21 @@ def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: # we can maybe generalize this by returning a list and stacking them for ijepa in the forward B, N, D = x.shape M = len(masks) - idx = torch.stack([m.to(x.device, dtype=torch.long) for m in masks], dim=1) # [B, M, K] - x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] - out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] - return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] class IJEPA_Encoder(VisionTransformer): + """IJEPA encoder. + + Args: + ijepa_in_dim: Input dimension of the encoder, which is the patch dimension after re-arranging the image. + """ + def __init__(self, *args, **kwargs): self.weight_init = kwargs.get("weight_init", "") self.fix_init = kwargs.get("fix_init", False) @@ -208,12 +208,21 @@ def encode_image(self, image: torch.Tensor) -> torch.Tensor: class IJEPA_Predictor(VisionTransformer): + """IJEPA predictor, handles the logic of conditioning the predictor based on the context and target masks. + + Args: + predictor_num_patches: Number of patches in the predictor. This is typically equal to the number of patches in the context/target encoder. + ijepa_encoder_dim: Dimension of the IJEPA context/target encoder. This is used to up/down project the latents. 
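    Example:
        An illustrative sketch only; it assumes the ``predictor_kwargs`` dict and the
        ``context_patches`` / ``masks_target`` tensors prepared elsewhere in this script.

            predictor = IJEPA_Predictor(**predictor_kwargs)
            preds = predictor.predict_targets(context_patches, masks_target)
            # preds: [B * len(masks_target), K_tgt, ijepa_encoder_dim]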
+ """ + def __init__(self, *args, **kwargs): - self.weight_init = kwargs.get('weight_init', '') - self.fix_init = kwargs.get('fix_init', False) - self.predictor_num_patches = kwargs.pop('predictor_num_patches') - self.ijepa_encoder_dim = kwargs.pop('ijepa_encoder_dim') - self.predictor_pos_embed = pos_embed(torch.zeros(1, self.predictor_num_patches, kwargs['embed_dim'])) + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + self.predictor_num_patches = kwargs.pop("predictor_num_patches") + self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") + self.predictor_pos_embed = pos_embed( + torch.zeros(1, self.predictor_num_patches, kwargs["embed_dim"]) + ) super().__init__(*args, **kwargs) self.predictor_pos_embed = nn.Parameter( self.predictor_pos_embed, requires_grad=False @@ -235,7 +244,7 @@ def predict_targets( B, *_ = context_patches.shape M = len(masks_target) - ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] + ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] # target position embeddings (stacked per mask): [B*M, K_tgt, D] pos_all = self.predictor_pos_embed.expand(B, -1, -1) @@ -288,14 +297,16 @@ def forward(self: ssl.Module, batch, stage): predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss - image_patches = target_encoder.patchify(batch['image']) - target_patches = target_encoder.project_patches(image_patches) - pos_embedding = pos_embed(target_patches) - target_patches = target_patches + pos_embedding - out['embedding'] = target_encoder.encode_patches(target_patches, with_layernorm=True) - out['sum_embedding'] = out['embedding'].sum(dim=1) - out['flat_embedding'] = out['embedding'].reshape(out['embedding'].shape[0], -1) - + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) + out["sum_embedding"] = out["embedding"].sum(dim=1) + out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + if not self.training: return out @@ -339,7 +350,7 @@ def forward(self: ssl.Module, batch, stage): rankme = ssl.callbacks.RankMe( name="rankme", target="flat_embedding", - queue_length=512, # NOTE must be >= batch_size + queue_length=512, # NOTE must be >= batch_size target_shape=(num_patches, 768), ) From 979fc9184ea41155195c0bbe7b2f7c266fbf9c6a Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 07:22:19 +0000 Subject: [PATCH 12/28] precommits --- stable_ssl/data/transforms.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 407544a4..21be7b4a 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -759,13 +759,15 @@ def __call__(self, x): if isinstance(source, PIL.Image.Image): H, W = source.size elif isinstance(source, torch.Tensor): - # NOTE assumes _HW + # NOTE assumes _HW H, W = source.shape[-2:] else: - raise ValueError(f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead.") + raise ValueError( + f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead." 
+ ) - scales = [self.context_scale, *self.target_scales] - aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] + scales = [self.context_scale, *self.target_scales] + aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios] context_mask, *target_masks = multi_block_mask( H // self.patch_size, W // self.patch_size, @@ -778,7 +780,9 @@ def __call__(self, x): context_mask &= ~mask x[self.target_context] = torch.nonzero(context_mask).flatten().squeeze() - x[self.target_targets] = [torch.nonzero(mask).flatten().squeeze() for mask in target_masks] + x[self.target_targets] = [ + torch.nonzero(mask).flatten().squeeze() for mask in target_masks + ] x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) return x From 5b1329bd4222824ec4ab2113a745d7ce639c2696 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 07:39:16 +0000 Subject: [PATCH 13/28] inet test script --- .../tests/scripts/train_ijepa_cifar10.py | 2 - .../tests/scripts/train_ijepa_inet1k.py | 401 ++++++++++++++++++ 2 files changed, 401 insertions(+), 2 deletions(-) create mode 100644 stable_ssl/tests/scripts/train_ijepa_inet1k.py diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 2a5e565f..1e8ed6bc 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -132,8 +132,6 @@ def pos_embed(patches: torch.Tensor) -> torch.Tensor: def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: - # TODO Does this assume all masks have the same number of indices? - # we can maybe generalize this by returning a list and stacking them for ijepa in the forward B, N, D = x.shape M = len(masks) idx = torch.stack( diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py new file mode 100644 index 00000000..b3960d73 --- /dev/null +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -0,0 +1,401 @@ +import math + +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +import stable_ssl as ssl +from stable_ssl.backbone.utils import TeacherStudentModule +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 32, 32, 4 +crop_height, crop_width = 160, 160 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +# We precompute these so the predictor can make sinusoidal posembeds +num_patches = (height // patch_size) * (width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size + +# Based on the in1k_vith14_ep300.yaml config in the ijepa repository +mask_transform_kwargs = dict( + patch_size=patch_size, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=20, +) + + +train_transform = transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) +# Don't mask during validation +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + 
transforms.CenterCrop((height, width)), + transforms.ToImage(mean=mean, std=std), +) + +# Use torchvision CIFAR-10 wrapped in FromTorchDataset +inet1k_train = ssl.data.HFDataset( + path="frgfm/imagenette", + name="160px", + split="train", + transform=val_transform, +) + +inet1k_val = ssl.data.HFDataset( + path="frgfm/imagenette", + name="160px", + split="val", + transform=val_transform, +) + +# TODO For some reason streaming=True hangs. I could also use noface imagenet from randall-lab/ on hf +# inet1k_train = ssl.data.HFDataset( +# path="mlx-vision/imagenet-1k", +# split="train", +# transform=train_transform, +# streaming=True, +# ) + +# inet1k_val = ssl.data.HFDataset( +# path="mlx-vision/imagenet-1k", +# split="val", +# transform=val_transform, +# streaming=True, +# ) + +class IndexedDataset(Dataset): + """Custom dataset wrapper that adds sample_idx to each sample.""" + + def __init__(self, dataset, transform=None): + super().__init__(transform) + self.dataset = dataset + + def __getitem__(self, idx): + image, label = self.dataset[idx] + sample = {"image": image, "label": label, "sample_idx": idx} + return self.process_sample(sample) + + def __len__(self): + return len(self.dataset) + + +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + + +train_dataset = IndexedDataset(inet1k_train, transform=train_transform) +# IJEPA does not use multi-view sampling like SimCLR etc. because it processes +# single views and handles masking at the model level +train = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=512, + shuffle=True, # Regular shuffling, no RepeatedRandomSampler + num_workers=0, + drop_last=True, + collate_fn=standardize_masks, +) + +val_dataset = IndexedDataset(inet1k_val, transform=val_transform) +val = torch.utils.data.DataLoader( + dataset=val_dataset, + batch_size=128, + num_workers=0, + shuffle=True, +) + + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + + +def apply_masks(x: torch.Tensor, *masks: torch.Tensor) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + + +class IJEPA_Encoder(VisionTransformer): + """IJEPA encoder. + + Args: + ijepa_in_dim: Input dimension of the encoder, which is the patch dimension after re-arranging the image. 
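    Example:
        An illustrative sketch only; it assumes the ``encoder_kwargs`` dict defined later
        in this script and a normalized image batch ``images`` of shape [B, 3, H, W].

            encoder = IJEPA_Encoder(**encoder_kwargs)
            embeddings = encoder.encode_image(images)  # [B, num_patches, embed_dim]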
+ """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + ijepa_in_dim = kwargs.pop("ijepa_in_dim") + super().__init__(*args, **kwargs) + + self.ijepa_patch_project = nn.Linear(ijepa_in_dim, self.embed_dim) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def patchify(self, image: torch.Tensor) -> torch.Tensor: + """Convert image tensor into patches. + + Args: + image: Tensor of shape [B, C, H, W] + + Returns: + patches: Tensor of shape [B, N, P*P*C] where: + N = number of patches (H/P * W/P) + P = patch size + """ + B, C, H, W = image.shape + P = patch_size + + # Unfold into patches + patches = image.unfold(2, P, P).unfold(3, P, P) + + # Reshape to [B, num_patches, patch_dim] + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) + + return patches + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + # assume they are already reshaped to patches + return self.ijepa_patch_project(patches) + + def encode_patches( + self, patches: torch.Tensor, with_layernorm: bool = True + ) -> torch.Tensor: + x = self.blocks(patches) + if with_layernorm: + x = F.layer_norm(x, (x.size(-1),)) + return x + + def encode_image(self, image: torch.Tensor) -> torch.Tensor: + patches = self.patchify(image) + patches = self.project_patches(patches) + patches = patches + pos_embed(patches) + return self.encode_patches(patches) + + +class IJEPA_Predictor(VisionTransformer): + """IJEPA predictor, handles the logic of conditioning the predictor based on the context and target masks. + + Args: + predictor_num_patches: Number of patches in the predictor. This is typically equal to the number of patches in the context/target encoder. + ijepa_encoder_dim: Dimension of the IJEPA context/target encoder. This is used to up/down project the latents. + """ + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + self.predictor_num_patches = kwargs.pop("predictor_num_patches") + self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") + self.predictor_pos_embed = pos_embed( + torch.zeros(1, self.predictor_num_patches, kwargs["embed_dim"]) + ) + super().__init__(*args, **kwargs) + self.predictor_pos_embed = nn.Parameter( + self.predictor_pos_embed, requires_grad=False + ) + self.predictor_inproj = nn.Linear(self.ijepa_encoder_dim, self.embed_dim) + self.predictor_outproj = nn.Linear(self.embed_dim, self.ijepa_encoder_dim) + self.predictor_norm = nn.LayerNorm(self.embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def predict_targets( + self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] + ) -> torch.Tensor: + # NOTE context_patches already positionally embedded + B, *_ = context_patches.shape + M = len(masks_target) + + ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] + + # target position embeddings (stacked per mask): [B*M, K_tgt, D] + pos_all = self.predictor_pos_embed.expand(B, -1, -1) + tgt_pos = apply_masks(pos_all, *masks_target) + + # repeat context across M target blocks: [B*M, N_ctx, D]. 
this means that + # the predictor predicts each target block independently, and not their union, + # as the ijepa repo does + ctx = ctx.repeat_interleave(M, dim=0) + + # mask tokens placed at target positions + N_tgt = tgt_pos.size(1) + pred_tokens = self.mask_token.expand(B * M, N_tgt, -1) + tgt_pos + + # each target block now gets predicted: [B*M, N_ctx+N_tgt, D] + x = torch.cat([ctx, pred_tokens], dim=1) + x = self.blocks(x) + x = self.predictor_norm(x) + + pred = x[:, -N_tgt:] + pred = self.predictor_outproj(pred) + return pred + + +encoder_kwargs = dict( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=False, + ijepa_in_dim=patch_channel_dim, +) +predictor_kwargs = dict( + patch_size=patch_size, + embed_dim=384, + depth=6, + num_heads=6, + qkv_bias=False, + ijepa_encoder_dim=768, + predictor_num_patches=num_patches, +) + + +def forward(self: ssl.Module, batch, stage): + out = {} + target_encoder: IJEPA_Encoder = ( + self.target_encoder.teacher + ) # NOTE Would this break anything? + context_encoder: IJEPA_Encoder = self.context_encoder + predictor: IJEPA_Predictor = self.predictor + ijepa_loss: nn.Module = self.ijepa_loss + + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) + out["sum_embedding"] = out["embedding"].sum(dim=1) + out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + + if not self.training: + return out + + mask_context, masks_target = batch["mask_context"], batch["masks_target"] + # target encoder is applied on full patches, then masked + out["target_patches"] = apply_masks(out["embedding"], *masks_target) + # context encoder is applied on masked patches + context_patches = apply_masks(image_patches, mask_context) + context_patches = context_encoder.project_patches(context_patches) + context_patches = context_patches + apply_masks(pos_embedding, mask_context) + out["context_patches"] = context_encoder.encode_patches( + context_patches, with_layernorm=False + ) + out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) + + out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) + return out + + +module = ssl.Module( + context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), + target_encoder=TeacherStudentModule(ctx), + predictor=IJEPA_Predictor(**predictor_kwargs), + forward=forward, + ijepa_loss=F.smooth_l1_loss, +) + + +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "sum_embedding", + "label", + probe=torch.nn.Linear(768, 10), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(10), + "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + }, +) + +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_embedding", + queue_length=512, # NOTE must be >= batch_size + target_shape=(num_patches, 768), +) + +trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, + precision="16-mixed", + enable_checkpointing=False, +) + +# Initialize W&B logger with explicit settings +wandb_logger = WandbLogger( + project="ijepa-imagenette", + entity="slightly-more-badass", # Your W&B entity + name="ijepa-cifar10-run", + log_model=False, # Set to True if you want to save model artifacts + offline=True, # Ensure offline mode +) + 
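# A tiny, self-contained illustration of the apply_masks helper defined above.
# The tensors below are made-up values for demonstration and are not used by
# the training run, so the example is kept commented out:
#
#   x = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)  # [B=2, N=4, D=3]
#   mask = torch.tensor([[0, 2], [1, 3]])                       # per-sample patch indices
#   out = apply_masks(x, mask)                                   # [B*M, K, D] == [2, 2, 3]
#   assert torch.equal(out[0], x[0, [0, 2]]) and torch.equal(out[1], x[1, [1, 3]])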
+trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first + callbacks=[linear_probe, rankme], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, +) +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() From 6254db3e1c400a4f309e54434f69ce39bf9a28f4 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 16:55:31 +0000 Subject: [PATCH 14/28] IJEPA INET1k HF dataset --- .../tests/scripts/train_ijepa_inet1k.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index b3960d73..c1fdf802 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -1,5 +1,11 @@ -import math +import os + +os.environ["HF_HOME"] = "/mnt/data/sami/huggingface" +os.environ["HF_DATASETS_CACHE"] = "/mnt/data/sami/huggingface/datasets" +os.environ["HF_HUB_CACHE"] = "/mnt/data/sami/huggingface/hub" +os.environ["TORCH_HOME"] = "/mnt/data/sami/torch" +import math import lightning as pl import torch import torch.nn.functional as F @@ -17,8 +23,8 @@ mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 32, 32, 4 -crop_height, crop_width = 160, 160 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +height, width, patch_size = 256, 256, 16 +crop_height, crop_width = 224, 224 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 # We precompute these so the predictor can make sinusoidal posembeds num_patches = (height // patch_size) * (width // patch_size) patch_channel_dim = 3 * patch_size * patch_size @@ -40,6 +46,7 @@ transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), ) + # Don't mask during validation val_transform = transforms.Compose( transforms.RGB(), @@ -48,21 +55,20 @@ transforms.ToImage(mean=mean, std=std), ) -# Use torchvision CIFAR-10 wrapped in FromTorchDataset + inet1k_train = ssl.data.HFDataset( - path="frgfm/imagenette", - name="160px", + path="ilsvrc/imagenet-1k", split="train", transform=val_transform, ) inet1k_val = ssl.data.HFDataset( - path="frgfm/imagenette", - name="160px", + path="ilsvrc/imagenet-1k", split="val", transform=val_transform, ) + # TODO For some reason streaming=True hangs. 
I could also use noface imagenet from randall-lab/ on hf # inet1k_train = ssl.data.HFDataset( # path="mlx-vision/imagenet-1k", From 129997e55588353284b4a5bf0ff2128f608e63cf Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 17:22:50 +0000 Subject: [PATCH 15/28] mae, random masking, inet --- stable_ssl/data/transforms.py | 39 ++ .../tests/scripts/train_ijepa_inet1k.py | 4 +- stable_ssl/tests/scripts/train_mae_cifar10.py | 343 ++++++++++++++++++ 3 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 stable_ssl/tests/scripts/train_mae_cifar10.py diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 21be7b4a..7ad57a8b 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -787,6 +787,45 @@ def __call__(self, x): return x +class RandomMask(Transform): + def __init__( + self, + patch_size=16, + mask_ratio=0.75, + source: str = "image", + target_visible: str = "mask_visible", + target_masked: str = "mask_masked", + ): + super().__init__() + self.patch_size = patch_size + self.mask_ratio = mask_ratio + self.source = source + self.target_visible = target_visible + self.target_masked = target_masked + + def __call__(self, x): + source = self.nested_get(x, self.source) + if isinstance(source, PIL.Image.Image): + H, W = source.size + elif isinstance(source, torch.Tensor): + # NOTE assumes _HW + H, W = source.shape[-2:] + else: + raise ValueError( + f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead." + ) + + num_patches = (H // self.patch_size) * (W // self.patch_size) + num_masked = int(num_patches * self.mask_ratio) + indices = torch.randperm(num_patches) + mask_masked = indices[:num_masked] + mask_visible = indices[num_masked:] + + x[self.target_visible] = mask_visible + x[self.target_masked] = mask_masked + return x + + # class MultiTransforms(v2.Transform): # def __init__(self, transforms, repeats: list = None): # super().__init__() diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index c1fdf802..53202942 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -57,13 +57,13 @@ inet1k_train = ssl.data.HFDataset( - path="ilsvrc/imagenet-1k", + path="ILSVRC/imagenet-1k", split="train", transform=val_transform, ) inet1k_val = ssl.data.HFDataset( - path="ilsvrc/imagenet-1k", + path="ILSVRC/imagenet-1k", split="val", transform=val_transform, ) diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py new file mode 100644 index 00000000..2ee83210 --- /dev/null +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -0,0 +1,343 @@ +import math + +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +import stable_ssl as ssl +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 32, 32, 4 +crop_height, crop_width = 32, 32 +num_patches = (height // patch_size) * (width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size + +mask_transform_kwargs = dict( + patch_size=patch_size, + mask_ratio=0.75, + source="image", + target_visible="mask_visible", + 
target_masked="mask_masked", +) + +train_transform = transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0)), + transforms.RandomMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) + +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + transforms.CenterCrop((height, width)), + transforms.ToImage(mean=mean, std=std), +) + +# Use torchvision CIFAR-10 +cifar_train = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=True, download=True +) +cifar_val = torchvision.datasets.CIFAR10( + root="/tmp/cifar10", train=False, download=True +) + +class IndexedDataset(Dataset): + def __init__(self, dataset, transform=None): + super().__init__(transform) + self.dataset = dataset + + def __getitem__(self, idx): + image, label = self.dataset[idx] + sample = {"image": image, "label": label, "sample_idx": idx} + return self.process_sample(sample) + + def __len__(self): + return len(self.dataset) + +def standardize_masks(batch: list[dict]): + """Simpler collate function for MAE - just handle visible/masked indices""" + visible_indices = [item.pop("mask_visible") for item in batch] + masked_indices = [item.pop("mask_masked") for item in batch] + batch = torch.utils.data.default_collate(batch) + + # Standardize to minimum length + min_visible = min(len(vis) for vis in visible_indices) + min_masked = min(len(mask) for mask in masked_indices) + + visible_batch = [vis[:min_visible] for vis in visible_indices] + masked_batch = [mask[:min_masked] for mask in masked_indices] + + batch["mask_visible"] = torch.utils.data.default_collate(visible_batch) + batch["mask_masked"] = torch.utils.data.default_collate(masked_batch) + return batch + +train_dataset = IndexedDataset(cifar_train, transform=train_transform) +train = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=512, + shuffle=True, + num_workers=32, + drop_last=True, + collate_fn=standardize_masks, +) + +val_dataset = IndexedDataset(cifar_val, transform=val_transform) +val = torch.utils.data.DataLoader( + dataset=val_dataset, + batch_size=128, + num_workers=32, + shuffle=True, +) + +data = ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor) -> torch.Tensor: + return ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + + +def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Apply single mask to tensor""" + B, N, D = x.shape + mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) + return torch.gather(x, dim=1, index=mask_expanded) + + +class MAE_Encoder(VisionTransformer): + """MAE encoder - processes only visible patches""" + + def __init__(self, *args, **kwargs): + self.weight_init = kwargs.get("weight_init", "") + self.fix_init = kwargs.get("fix_init", False) + mae_in_dim = kwargs.pop("mae_in_dim") + super().__init__(*args, **kwargs) + + self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) + + if self.weight_init != "skip": + self.init_weights(self.weight_init) + if self.fix_init: + self.fix_init_weight() + + def patchify(self, image: torch.Tensor) -> torch.Tensor: + B, C, H, W = image.shape + P = patch_size + + patches = image.unfold(2, P, P).unfold(3, P, P) + patches = patches.permute(0, 2, 3, 1, 4, 5) + num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] + patches = patches.reshape(B, num_patch_h * num_patch_w, P * 
P * C) + return patches + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + return self.mae_patch_project(patches) + + def encode_patches(self, patches: torch.Tensor) -> torch.Tensor: + x = self.blocks(patches) + x = self.norm(x) + return x + +class MAE_Decoder(nn.Module): + """MAE decoder - reconstructs full image from visible patches + mask tokens""" + + def __init__(self, encoder_dim=768, decoder_dim=512, decoder_depth=8, decoder_heads=16): + super().__init__() + self.decoder_dim = decoder_dim + self.decoder_embed = nn.Linear(encoder_dim, decoder_dim) + + # Decoder transformer blocks + self.decoder_blocks = nn.ModuleList([ + nn.TransformerEncoderLayer( + d_model=decoder_dim, + nhead=decoder_heads, + dim_feedforward=decoder_dim * 4, + dropout=0.0, + activation='gelu', + batch_first=True + ) + for _ in range(decoder_depth) + ]) + + self.decoder_norm = nn.LayerNorm(decoder_dim) + self.decoder_pred = nn.Linear(decoder_dim, patch_channel_dim) # Predict pixel values + + # Learnable mask token + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) + + # Fixed positional embeddings + self.decoder_pos_embed = nn.Parameter( + torch.from_numpy( + get_2d_sincos_pos_embed(decoder_dim, int(math.sqrt(num_patches))) + ).float(), + requires_grad=False + ) + + def forward(self, x_visible: torch.Tensor, mask_visible: torch.Tensor, mask_masked: torch.Tensor) -> torch.Tensor: + """ + Args: + x_visible: [B, N_visible, encoder_dim] - encoded visible patches + mask_visible: [B, N_visible] - indices of visible patches + mask_masked: [B, N_masked] - indices of masked patches + Returns: + x_reconstructed: [B, N_masked, patch_pixels] - reconstructed masked patches + """ + B = x_visible.shape[0] + N_visible = mask_visible.shape[1] + N_masked = mask_masked.shape[1] + + # Project encoder features to decoder dimension + x_visible = self.decoder_embed(x_visible) # [B, N_visible, decoder_dim] + + # Create mask tokens for masked positions + mask_tokens = self.mask_token.expand(B, N_masked, -1) # [B, N_masked, decoder_dim] + + # Get positional embeddings for all patches + pos_embed_all = self.decoder_pos_embed.unsqueeze(0).expand(B, -1, -1) # [B, N_total, decoder_dim] + pos_visible = apply_mask(pos_embed_all, mask_visible) # [B, N_visible, decoder_dim] + pos_masked = apply_mask(pos_embed_all, mask_masked) # [B, N_masked, decoder_dim] + + # Add positional embeddings + x_visible = x_visible + pos_visible + mask_tokens = mask_tokens + pos_masked + + # Combine visible and masked tokens + # For simplicity, concatenate them (real MAE uses more sophisticated ordering) + x_full = torch.cat([x_visible, mask_tokens], dim=1) # [B, N_visible + N_masked, decoder_dim] + + # Apply decoder transformer blocks + for block in self.decoder_blocks: + x_full = block(x_full) + + x_full = self.decoder_norm(x_full) + + # Extract predictions for masked patches only + x_masked_pred = x_full[:, N_visible:] # [B, N_masked, decoder_dim] + + # Predict pixel values + x_reconstructed = self.decoder_pred(x_masked_pred) # [B, N_masked, patch_pixels] + + return x_reconstructed + +encoder_kwargs = dict( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=True, # MAE typically uses bias + mae_in_dim=patch_channel_dim, +) + +decoder_kwargs = dict( + encoder_dim=768, + decoder_dim=512, + decoder_depth=8, + decoder_heads=16, +) + +def forward(self: ssl.Module, batch, stage): + out = {} + encoder: MAE_Encoder = self.encoder + decoder: MAE_Decoder = self.decoder + mae_loss: nn.Module = self.mae_loss + + # 
Patchify and get all patches with positions + image_patches = encoder.patchify(batch["image"]) # [B, N, patch_pixels] + all_patches = encoder.project_patches(image_patches) # [B, N, embed_dim] + pos_embedding = pos_embed(all_patches) + all_patches = all_patches + pos_embedding + + if not self.training: + # For validation, encode all patches for downstream tasks + out["embedding"] = encoder.encode_patches(all_patches) + out["sum_embedding"] = out["embedding"].sum(dim=1) + out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + return out + + mask_visible, mask_masked = batch["mask_visible"], batch["mask_masked"] + + # Encode only visible patches + visible_patches = apply_mask(all_patches, mask_visible) # [B, N_visible, embed_dim] + encoded_visible = encoder.encode_patches(visible_patches) # [B, N_visible, embed_dim] + + # Decode to reconstruct masked patches + reconstructed_masked = decoder(encoded_visible, mask_visible, mask_masked) # [B, N_masked, patch_pixels] + + # Get ground truth masked patches + gt_masked_patches = apply_mask(image_patches, mask_masked) # [B, N_masked, patch_pixels] + + # Compute reconstruction loss (MSE in pixel space) + out["loss"] = mae_loss(reconstructed_masked, gt_masked_patches) + + # For monitoring + out["embedding"] = encoded_visible + out["sum_embedding"] = encoded_visible.sum(dim=1) + out["reconstructed"] = reconstructed_masked + out["ground_truth"] = gt_masked_patches + + return out + +module = ssl.Module( + encoder=MAE_Encoder(**encoder_kwargs), + decoder=MAE_Decoder(**decoder_kwargs), + forward=forward, + mae_loss=F.mse_loss, # Pixel MSE loss +) + +# Note: Linear probe uses visible patches only during training +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "sum_embedding", + "label", + probe=torch.nn.Linear(768, 10), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(10), + "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + }, +) + +# RankMe on encoder outputs +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_embedding", + queue_length=512, + target_shape=(num_patches, 768), +) + +# Initialize W&B logger +wandb_logger = WandbLogger( + project="mae-cifar10", + entity="slightly-more-badass", + name="mae-cifar10-run", + log_model=False, + offline=True, +) + +trainer = pl.Trainer( + max_epochs=100, # MAE typically needs more epochs + num_sanity_val_steps=0, + callbacks=[linear_probe, rankme], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() \ No newline at end of file From 509afdd93138c2f0f53d0bba5d6c5876c73bb804 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 21:44:06 +0000 Subject: [PATCH 16/28] WIP MAE simplifying arch --- stable_ssl/data/transforms.py | 27 +++- stable_ssl/tests/scripts/mae_arch_test.py | 109 +++++++++++++++ .../tests/scripts/train_ijepa_cifar10.py | 7 - .../tests/scripts/train_ijepa_inet1k.py | 127 +++++++----------- stable_ssl/tests/scripts/train_mae_cifar10.py | 74 +++++++--- 5 files changed, 230 insertions(+), 114 deletions(-) create mode 100644 stable_ssl/tests/scripts/mae_arch_test.py diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 7ad57a8b..c4f6691d 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -795,6 +795,8 @@ def __init__( source: str = "image", target_visible: str = "mask_visible", target_masked: str = 
"mask_masked", + target_ids_restore: str = "ids_restore", + target_len_keep: str = "len_keep", ): super().__init__() self.patch_size = patch_size @@ -802,6 +804,8 @@ def __init__( self.source = source self.target_visible = target_visible self.target_masked = target_masked + self.target_ids_restore = target_ids_restore + self.target_len_keep = target_len_keep def __call__(self, x): source = self.nested_get(x, self.source) @@ -816,16 +820,27 @@ def __call__(self, x): ) num_patches = (H // self.patch_size) * (W // self.patch_size) - num_masked = int(num_patches * self.mask_ratio) - indices = torch.randperm(num_patches) - mask_masked = indices[:num_masked] - mask_visible = indices[num_masked:] - + len_keep = int(num_patches * (1 - self.mask_ratio)) + + # Generate random noise and shuffle indices (like MAE) + noise = torch.rand(num_patches) + ids_shuffle = torch.argsort(noise) + ids_restore = torch.argsort(ids_shuffle) # inverse permutation + + # Split into visible and masked + mask_visible = ids_shuffle[:len_keep] # first len_keep are visible + mask_masked = ids_shuffle[len_keep:] # rest are masked + + # Add to sample x[self.target_visible] = mask_visible - x[self.target_masked] = mask_masked + x[self.target_masked] = mask_masked + x[self.target_ids_restore] = ids_restore # NEW: for reconstructing full sequence + x[self.target_len_keep] = len_keep + return x + # class MultiTransforms(v2.Transform): # def __init__(self, transforms, repeats: list = None): # super().__init__() diff --git a/stable_ssl/tests/scripts/mae_arch_test.py b/stable_ssl/tests/scripts/mae_arch_test.py new file mode 100644 index 00000000..c7f9c812 --- /dev/null +++ b/stable_ssl/tests/scripts/mae_arch_test.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from timm.models.vision_transformer import VisionTransformer + +encoder_kwargs = dict( + img_size=32, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + # include cls token for positional embedding: + # https://github.com/facebookresearch/mae/blob/main/models_mae.py#L68-L69 + no_embed_class=False, + norm_layer=nn.LayerNorm, +) + +decoder_kwargs = dict( + img_size=32, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + class_token=True, + no_embed_class=False, + norm_layer=nn.LayerNorm, +) + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # -------------------------------------------------------------------------- + # MAE encoder specifics + # replicated from timm's + self.num_patches = self.patch_embed.num_patches + self.num_prefix_tokens + # TODO Exclude this and add posembeds from outside ? Can we do that ? 
I dont think so since we work on raw pixels + # self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim), requires_grad=False) # fixed sin-cos embedding + + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + patch_size = kwargs.get('patch_size', 16) + super().__init__(*args, **kwargs) + # -------------------------------------------------------------------------- + # MAE decoder specifics + # replicated from timm's + self.num_patches = self.patch_embed.num_patches + self.num_prefix_tokens + self.out_proj = nn.Linear(self.embed_dim, self.num_patches * patch_size**2) + +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + pass + + +def forward_encoder(encoder: MAE_Encoder, images: torch.Tensor, mask:torch.Tensor) -> torch.Tensor: + patches = encoder.patch_embed(images) + posemb = pos_embed(patches, with_cls = True) + posemb_patches = patches + posemb[:, 1:, :] + cls_token = encoder.cls_token + posemb[:, :1, :] + cls_tokens = cls_token.expand(patches.shape[0], -1, -1) + + masked_patches = apply_mask(posemb_patches, mask) + + patches = torch.cat([cls_tokens, masked_patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + + patches = encoder.norm(patches) + return patches + +def apply_mask(patches: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # TODO + pass + +def forward_decoder(decoder: MAE_Decoder, batch: dict) -> torch.Tensor: + patches = batch["patches"] + ids_restore = batch["ids_restore"] + + patches = decoder.patch_embed(patches) + + # NOTE from mae repo directly + # ids_restore im assuming can be the total set of indices + # dim1 is the number of tokens to predict. +1 for cls token? + # so we dont predict cls token? + mask_tokens = decoder.mask_token.repeat( + patches.shape[0], + ids_restore.shape[1] + 1 - patches.shape[1], + 1 + ) + # we want to pos-embed according to their original positions + # so we unshuffle first using ids_restore which is the inverse permutation + patches_ = torch.cat([patches[:, 1:, :], mask_tokens], dim=1) + patches_ = torch.gather(patches_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, patches.shape[2])) # unshuffle + # append cls token + patches = torch.cat([patches[:, :1, :], patches_], dim=1) + patches = patches + pos_embed(patches, with_cls = True) + # apply transformer blocks + patches = decoder.blocks(patches) + patches = decoder.norm(patches) + # predictor projection + patches = decoder.pred(patches) + # remove cls token + return patches[:, 1:, :] + +def forward_mae(self, batch: dict, stage): + out = {} + encoder = self.encoder + decoder = self.decoder + + if stage == "train": + pass \ No newline at end of file diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 1e8ed6bc..590903e1 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -352,13 +352,6 @@ def forward(self: ssl.Module, batch, stage): target_shape=(num_patches, 768), ) -trainer = pl.Trainer( - max_epochs=6, - num_sanity_val_steps=0, - precision="16-mixed", - enable_checkpointing=False, -) - # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( project="ijepa-cifar10", diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index 53202942..315f8d3b 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -1,32 +1,30 @@ -import 
os - -os.environ["HF_HOME"] = "/mnt/data/sami/huggingface" -os.environ["HF_DATASETS_CACHE"] = "/mnt/data/sami/huggingface/datasets" -os.environ["HF_HUB_CACHE"] = "/mnt/data/sami/huggingface/hub" -os.environ["TORCH_HOME"] = "/mnt/data/sami/torch" - -import math -import lightning as pl -import torch -import torch.nn.functional as F -import torchmetrics -import torchvision -from lightning.pytorch.loggers import WandbLogger -from timm.models.vision_transformer import VisionTransformer -from torch import nn - -import stable_ssl as ssl -from stable_ssl.backbone.utils import TeacherStudentModule -from stable_ssl.data import transforms -from stable_ssl.data.utils import Dataset -from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +import torch +import math +import lightning as pl +import torch.nn.functional as F +import torchmetrics +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn + +import stable_ssl as ssl +from stable_ssl.backbone.utils import TeacherStudentModule +from stable_ssl.data import transforms +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +from lightning.pytorch.strategies import DDPStrategy + + +train_batch_size = 128 +val_batch_size = 128 +num_workers = 32 +num_classes = 1000 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] height, width, patch_size = 256, 256, 16 crop_height, crop_width = 224, 224 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 # We precompute these so the predictor can make sinusoidal posembeds -num_patches = (height // patch_size) * (width // patch_size) +num_patches = (crop_height // patch_size) * (crop_width // patch_size) patch_channel_dim = 3 * patch_size * patch_size # Based on the in1k_vith14_ep300.yaml config in the ijepa repository @@ -59,47 +57,16 @@ inet1k_train = ssl.data.HFDataset( path="ILSVRC/imagenet-1k", split="train", - transform=val_transform, + transform=train_transform, ) inet1k_val = ssl.data.HFDataset( path="ILSVRC/imagenet-1k", - split="val", + split="validation", transform=val_transform, ) -# TODO For some reason streaming=True hangs. I could also use noface imagenet from randall-lab/ on hf -# inet1k_train = ssl.data.HFDataset( -# path="mlx-vision/imagenet-1k", -# split="train", -# transform=train_transform, -# streaming=True, -# ) - -# inet1k_val = ssl.data.HFDataset( -# path="mlx-vision/imagenet-1k", -# split="val", -# transform=val_transform, -# streaming=True, -# ) - -class IndexedDataset(Dataset): - """Custom dataset wrapper that adds sample_idx to each sample.""" - - def __init__(self, dataset, transform=None): - super().__init__(transform) - self.dataset = dataset - - def __getitem__(self, idx): - image, label = self.dataset[idx] - sample = {"image": image, "label": label, "sample_idx": idx} - return self.process_sample(sample) - - def __len__(self): - return len(self.dataset) - - def standardize_masks(batch: list[dict]): context_indices = [item.pop("mask_context") for item in batch] target_indices = [item.pop("masks_target") for item in batch] @@ -123,24 +90,24 @@ def standardize_masks(batch: list[dict]): return batch -train_dataset = IndexedDataset(inet1k_train, transform=train_transform) # IJEPA does not use multi-view sampling like SimCLR etc. 
because it processes # single views and handles masking at the model level train = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=512, + dataset=inet1k_train, + batch_size=train_batch_size, shuffle=True, # Regular shuffling, no RepeatedRandomSampler - num_workers=0, + num_workers=num_workers, drop_last=True, collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, ) -val_dataset = IndexedDataset(inet1k_val, transform=val_transform) val = torch.utils.data.DataLoader( - dataset=val_dataset, - batch_size=128, - num_workers=0, - shuffle=True, + dataset=inet1k_val, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, ) @@ -305,8 +272,8 @@ def predict_targets( predictor_kwargs = dict( patch_size=patch_size, embed_dim=384, - depth=6, - num_heads=6, + depth=12, + num_heads=12, qkv_bias=False, ijepa_encoder_dim=768, predictor_num_patches=num_patches, @@ -364,33 +331,27 @@ def forward(self: ssl.Module, batch, stage): "linear_probe", "sum_embedding", "label", - probe=torch.nn.Linear(768, 10), + probe=torch.nn.Linear(768, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ - "top1": torchmetrics.classification.MulticlassAccuracy(10), - "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), }, ) rankme = ssl.callbacks.RankMe( name="rankme", target="flat_embedding", - queue_length=512, # NOTE must be >= batch_size + queue_length=min(512, train_batch_size), # NOTE must be >= batch_size target_shape=(num_patches, 768), ) -trainer = pl.Trainer( - max_epochs=6, - num_sanity_val_steps=0, - precision="16-mixed", - enable_checkpointing=False, -) # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( - project="ijepa-imagenette", + project="ijepa-inet1k", entity="slightly-more-badass", # Your W&B entity - name="ijepa-cifar10-run", + name="ijepa-inet1k-run", log_model=False, # Set to True if you want to save model artifacts offline=True, # Ensure offline mode ) @@ -402,6 +363,14 @@ def forward(self: ssl.Module, batch, stage): precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, + accelerator="gpu", + devices=8, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) ) + manager = ssl.Manager(trainer=trainer, module=module, data=data) manager() diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index 2ee83210..c7e2160a 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -14,24 +14,30 @@ from stable_ssl.data.utils import Dataset from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] height, width, patch_size = 32, 32, 4 crop_height, crop_width = 32, 32 num_patches = (height // patch_size) * (width // patch_size) patch_channel_dim = 3 * patch_size * patch_size +mask_ratio = 0.75 +num_visible_patches = int(num_patches * (1 - mask_ratio)) +num_classes = 10 +batch_size = 512 +val_batch_size = 128 mask_transform_kwargs = dict( patch_size=patch_size, - mask_ratio=0.75, + mask_ratio=mask_ratio, source="image", target_visible="mask_visible", target_masked="mask_masked", ) train_transform = transforms.Compose( - transforms.RGB(), - 
transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0)), + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic + transforms.RandomHorizontalFlip(), transforms.RandomMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), ) @@ -43,7 +49,7 @@ transforms.ToImage(mean=mean, std=std), ) -# Use torchvision CIFAR-10 + cifar_train = torchvision.datasets.CIFAR10( root="/tmp/cifar10", train=True, download=True ) @@ -51,7 +57,10 @@ root="/tmp/cifar10", train=False, download=True ) + class IndexedDataset(Dataset): + """Custom dataset wrapper that adds sample_idx to each sample.""" + def __init__(self, dataset, transform=None): super().__init__(transform) self.dataset = dataset @@ -64,6 +73,7 @@ def __getitem__(self, idx): def __len__(self): return len(self.dataset) + def standardize_masks(batch: list[dict]): """Simpler collate function for MAE - just handle visible/masked indices""" visible_indices = [item.pop("mask_visible") for item in batch] @@ -81,22 +91,25 @@ def standardize_masks(batch: list[dict]): batch["mask_masked"] = torch.utils.data.default_collate(masked_batch) return batch + train_dataset = IndexedDataset(cifar_train, transform=train_transform) +val_dataset = IndexedDataset(cifar_val, transform=val_transform) train = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=512, + batch_size=batch_size, shuffle=True, - num_workers=32, + num_workers=0, drop_last=True, collate_fn=standardize_masks, + pin_memory=True, ) -val_dataset = IndexedDataset(cifar_val, transform=val_transform) val = torch.utils.data.DataLoader( dataset=val_dataset, - batch_size=128, - num_workers=32, - shuffle=True, + batch_size=val_batch_size, + num_workers=0, + shuffle=False, + pin_memory=True, ) data = ssl.data.DataModule(train=train, val=val) @@ -146,6 +159,16 @@ def patchify(self, image: torch.Tensor) -> torch.Tensor: patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) return patches + def unpatchify(self, patches: torch.Tensor) -> torch.Tensor: + B, N, P2C = patches.shape + P = patch_size + C = patch_channel_dim // (P * P) + assert P2C == C * P * P + H = W = (num_patches * P) // (P * P) + patches = patches.reshape(B, H, W, P, P) + patches = patches.permute(0, 3, 1, 4, 2) # [B, P, H, P, W] + return patches.reshape(B, P * P * C, H, W) # [B, P * P * C, H, W] + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: return self.mae_patch_project(patches) @@ -255,7 +278,7 @@ def forward(self: ssl.Module, batch, stage): out = {} encoder: MAE_Encoder = self.encoder decoder: MAE_Decoder = self.decoder - mae_loss: nn.Module = self.mae_loss + mae_loss: nn.Module = self.mae_loss # Patchify and get all patches with positions image_patches = encoder.patchify(batch["image"]) # [B, N, patch_pixels] @@ -286,11 +309,11 @@ def forward(self: ssl.Module, batch, stage): out["loss"] = mae_loss(reconstructed_masked, gt_masked_patches) # For monitoring - out["embedding"] = encoded_visible - out["sum_embedding"] = encoded_visible.sum(dim=1) - out["reconstructed"] = reconstructed_masked - out["ground_truth"] = gt_masked_patches - + out["embedding"] = encoded_visible + out["reconstructed"] = reconstructed_masked + out["sum_embedding"] = out["embedding"].sum(dim=1) + out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + out["ground_truth"] = gt_masked_patches return out module = ssl.Module( @@ -305,11 +328,11 @@ def forward(self: ssl.Module, batch, stage): "linear_probe", "sum_embedding", 
"label", - probe=torch.nn.Linear(768, 10), + probe=torch.nn.Linear(768, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ - "top1": torchmetrics.classification.MulticlassAccuracy(10), - "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), }, ) @@ -317,8 +340,8 @@ def forward(self: ssl.Module, batch, stage): rankme = ssl.callbacks.RankMe( name="rankme", target="flat_embedding", - queue_length=512, - target_shape=(num_patches, 768), + queue_length=min(512, batch_size), + target_shape=(num_visible_patches, 768), ) # Initialize W&B logger @@ -331,12 +354,19 @@ def forward(self: ssl.Module, batch, stage): ) trainer = pl.Trainer( - max_epochs=100, # MAE typically needs more epochs + max_epochs=6, num_sanity_val_steps=0, callbacks=[linear_probe, rankme], precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, + accelerator="gpu", + devices=1, + # strategy=DDPStrategy( + # find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + # static_graph=True, + # gradient_as_bucket_view=True, + # ) ) manager = ssl.Manager(trainer=trainer, module=module, data=data) From de7f9fb9104e09be0874b4ed68c576b6fc5a6735 Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 22:04:29 +0000 Subject: [PATCH 17/28] more testing archs --- stable_ssl/tests/scripts/mae_arch_test2.py | 207 +++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 stable_ssl/tests/scripts/mae_arch_test2.py diff --git a/stable_ssl/tests/scripts/mae_arch_test2.py b/stable_ssl/tests/scripts/mae_arch_test2.py new file mode 100644 index 00000000..da5891d9 --- /dev/null +++ b/stable_ssl/tests/scripts/mae_arch_test2.py @@ -0,0 +1,207 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.vision_transformer import VisionTransformer + +# -------------------- utils: positional embeddings -------------------- + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = True): + """Returns [1, grid_size*grid_size + cls, embed_dim] fixed sin-cos embeddings.""" + def get_1d_sin_cos(d, n): + omega = torch.arange(d // 2, dtype=torch.float32) / (d // 2) + omega = 1.0 / (10000 ** omega) # [d/2] + pos = torch.arange(n, dtype=torch.float32) # [n] + out = torch.einsum('n,d->nd', pos, omega) # [n, d/2] + return torch.cat([out.sin(), out.cos()], dim=1) # [n, d] + + assert embed_dim % 2 == 0 + # 2D grid + pe_h = get_1d_sin_cos(embed_dim // 2, grid_size) # [G, D/2] + pe_w = get_1d_sin_cos(embed_dim // 2, grid_size) # [G, D/2] + pe = ( + torch.stack([pe_h.unsqueeze(1).expand(-1, grid_size, -1), + pe_w.unsqueeze(0).expand(grid_size, -1, -1)], dim=2) + .reshape(grid_size * grid_size, embed_dim) + ) # [G*G, D] + if cls_token: + pe = torch.cat([torch.zeros(1, embed_dim), pe], dim=0) # [1+G*G, D] + return pe.unsqueeze(0) # [1, 1+G*G, D] + +def pos_embed(x: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + """ + x: [B, N(=H*W or tokens), D] + Returns fixed sin-cos PE of shape [B, (1+N), D] if with_cls else [B, N, D]. 
+ """ + B, N, D = x.shape + G = int(math.sqrt(N)) # assumes square grid of patches + assert G * G == N, f"pos_embed expects square grid, got N={N}" + pe = get_2d_sincos_pos_embed(D, G, cls_token=with_cls).to(x.device, x.dtype) + return pe.expand(B, -1, -1) # [B, 1+N, D] or [B, N, D] depending on with_cls + +# -------------------- utils: masking and patchify -------------------- + +def apply_mask(x: torch.Tensor, ids_keep: torch.Tensor) -> torch.Tensor: + """ + Keep selection by indices. + x: [B, N, D], ids_keep: [B, K] + returns: [B, K, D] + """ + B, N, D = x.shape + idx = ids_keep.unsqueeze(-1).expand(-1, -1, D) # [B, K, D] + return x.gather(dim=1, index=idx) + + +# -------------------- encoder -------------------- + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches # do NOT add prefix here + + def patchify(self, images: torch.Tensor) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = self.patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + + +def forward_encoder(encoder: MAE_Encoder, images: torch.Tensor, ids_keep: torch.Tensor) -> torch.Tensor: + """ + Returns visible token latents (with cls prepended), projected by encoder blocks. + """ + # Patch embed: [B, N, D] + x = encoder.patch_embed(images) + # Fixed PE + pe = pos_embed(x, with_cls=True) # [B, 1+N, D] + x = x + pe[:, 1:, :] # add PE to patch tokens + # cls token with PE + cls_tok = encoder.cls_token + pe[:, :1, :] # [B, 1, D] + cls_tok = cls_tok.expand(x.shape[0], -1, -1) + + # Keep only visible tokens + x_vis = apply_mask(x, ids_keep) # [B, K, D] + # prepend cls + x = torch.cat([cls_tok, x_vis], dim=1) # [B, 1+K, D] + + # Blocks + for blk in encoder.blocks: + x = blk(x) + x = encoder.norm(x) + return x # [B, 1+K, D] + +# -------------------- decoder -------------------- + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + self.enc_dim = kwargs.get('enc_dim', 768) + super().__init__(*args, **kwargs) + + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + + # Map encoder dim -> decoder dim (use embed_dim of this decoder) + # You must set this externally once you know the encoder dim: + self.decoder_embed = nn.Linear(self.enc_dim, self.embed_dim) # set to Linear(enc_dim, self.embed_dim) after init + + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + +def forward_decoder( + decoder: MAE_Decoder, + enc_tokens_vis: torch.Tensor, # [B, 1+K, D_enc] (cls + visible) + ids_restore: torch.Tensor # [B, N] +) -> torch.Tensor: + """ + Returns pixel predictions for all tokens (no cls): [B, N, P^2*C] + """ + B, _, D_enc = enc_tokens_vis.shape + + # project to decoder dim + x = decoder.decoder_embed(enc_tokens_vis) # [B, 1+K, D_dec] + + # prepare full grid by inserting mask tokens at missing positions + # 1) split cls vs visible + x_cls, x_vis = x[:, :1, :], x[:, 1:, :] # [B,1,D], 
[B,K,D] + # 2) number of tokens in grid + N = decoder.num_patches + K = x_vis.shape[1] + # 3) tokens to fill + mask_tokens = decoder.mask_token.expand(B, N - K, -1) # [B, N-K, D] + # 4) combine vis + mask, then unshuffle + x_ = torch.cat([x_vis, mask_tokens], dim=1) # [B, N, D] + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[-1])) # [B, N, D] + # 5) prepend cls + x = torch.cat([x_cls, x_], dim=1) # [B, 1+N, D] + + # add fixed PE for decoder + pe = pos_embed(x[:, 1:, :], with_cls=True) # reuse shape logic + x = x + pe # [B, 1+N, D] + + # transformer blocks + for blk in decoder.blocks: + x = blk(x) + x = decoder.norm(x) + + # predict pixels per token, drop cls + x = decoder.decoder_pred(x[:, 1:, :]) # [B, N, P^2*C] + return x + +# -------------------- MAE training forward -------------------- +mask_ratio = 0.75 + +def forward_mae(self, batch: dict, stage): + """ + Expects batch["image"]: [B,3,H,W] + Produces MAE reconstruction loss on masked patches. + """ + out = {} + encoder: MAE_Encoder = self.encoder + decoder: MAE_Decoder = self.decoder + + images = batch["image"] # [B,3,H,W] + ids_keep = batch["mask_visible"] + ids_restore = batch["ids_restore"] + mask_masked = batch["mask_masked"] + + B = images.shape[0] + N = encoder.num_patches + + # 2) encode visible tokens (cls+visible) + enc_tokens_vis = forward_encoder(encoder, images, ids_keep) # [B,1+K,D_enc] + + # 4) decode to pixel predictions for ALL tokens (no cls) + pred_pix = forward_decoder(decoder, enc_tokens_vis, ids_restore) # [B,N,P^2*C] + + # 5) ground-truth patches + target_pix = encoder.patchify(images) # [B,N,P^2*C] + + # 6) compute loss ONLY on masked tokens + mask_exp = mask_masked.unsqueeze(-1).type_as(pred_pix) # [B,N,1] + loss = ((pred_pix - target_pix) ** 2 * mask_exp).sum() / mask_exp.sum().clamp_min(1.0) + + # 7) populate outputs for logging / probes + if stage != "train": + return out + + out["loss"] = loss + out["mask_ratio"] = torch.tensor(mask_ratio, device=images.device) + out["ids_restore"] = ids_restore + out["mask"] = mask_masked + return out \ No newline at end of file From f7e76785b3723444310ade4da2f5a128015c0c1f Mon Sep 17 00:00:00 2001 From: Sami Date: Sun, 17 Aug 2025 22:06:43 +0000 Subject: [PATCH 18/28] more testing archs --- stable_ssl/tests/scripts/mae_arch_test2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_ssl/tests/scripts/mae_arch_test2.py b/stable_ssl/tests/scripts/mae_arch_test2.py index da5891d9..9380170c 100644 --- a/stable_ssl/tests/scripts/mae_arch_test2.py +++ b/stable_ssl/tests/scripts/mae_arch_test2.py @@ -194,7 +194,7 @@ def forward_mae(self, batch: dict, stage): # 6) compute loss ONLY on masked tokens mask_exp = mask_masked.unsqueeze(-1).type_as(pred_pix) # [B,N,1] - loss = ((pred_pix - target_pix) ** 2 * mask_exp).sum() / mask_exp.sum().clamp_min(1.0) + loss = self.loss_fn(pred_pix, target_pix, mask_exp) # 7) populate outputs for logging / probes if stage != "train": From 603977ac4c52c82ac06b21ff5648e6636e21d6f8 Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 18 Aug 2025 06:04:52 +0000 Subject: [PATCH 19/28] mae cifar10 --- stable_ssl/data/transforms.py | 6 +- .../tests/scripts/train_ijepa_inet1k.py | 6 +- stable_ssl/tests/scripts/train_mae_cifar10.py | 307 +++++++----------- 3 files changed, 124 insertions(+), 195 deletions(-) diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index c4f6691d..cfc0ddb2 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -757,9 +757,9 
@@ def __init__( def __call__(self, x): source = self.nested_get(x, self.source) if isinstance(source, PIL.Image.Image): - H, W = source.size + W, H = source.size # PIL is W,H elif isinstance(source, torch.Tensor): - # NOTE assumes _HW + # assumes H W H, W = source.shape[-2:] else: raise ValueError( @@ -810,7 +810,7 @@ def __init__( def __call__(self, x): source = self.nested_get(x, self.source) if isinstance(source, PIL.Image.Image): - H, W = source.size + W, H = source.size # PIL is W,H elif isinstance(source, torch.Tensor): # NOTE assumes _HW H, W = source.shape[-2:] diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index 315f8d3b..259be8bd 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -282,9 +282,7 @@ def predict_targets( def forward(self: ssl.Module, batch, stage): out = {} - target_encoder: IJEPA_Encoder = ( - self.target_encoder.teacher - ) # NOTE Would this break anything? + target_encoder: IJEPA_Encoder = self.target_encoder.teacher context_encoder: IJEPA_Encoder = self.context_encoder predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss @@ -357,7 +355,7 @@ def forward(self: ssl.Module, batch, stage): ) trainer = pl.Trainer( - max_epochs=6, + max_epochs=300, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first callbacks=[linear_probe, rankme], precision="16-mixed", diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index c7e2160a..bc07d2dd 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -26,6 +26,7 @@ num_classes = 10 batch_size = 512 val_batch_size = 128 +num_workers = 16 mask_transform_kwargs = dict( patch_size=patch_size, @@ -74,23 +75,6 @@ def __len__(self): return len(self.dataset) -def standardize_masks(batch: list[dict]): - """Simpler collate function for MAE - just handle visible/masked indices""" - visible_indices = [item.pop("mask_visible") for item in batch] - masked_indices = [item.pop("mask_masked") for item in batch] - batch = torch.utils.data.default_collate(batch) - - # Standardize to minimum length - min_visible = min(len(vis) for vis in visible_indices) - min_masked = min(len(mask) for mask in masked_indices) - - visible_batch = [vis[:min_visible] for vis in visible_indices] - masked_batch = [mask[:min_masked] for mask in masked_indices] - - batch["mask_visible"] = torch.utils.data.default_collate(visible_batch) - batch["mask_masked"] = torch.utils.data.default_collate(masked_batch) - return batch - train_dataset = IndexedDataset(cifar_train, transform=train_transform) val_dataset = IndexedDataset(cifar_val, transform=val_transform) @@ -98,25 +82,26 @@ def standardize_masks(batch: list[dict]): dataset=train_dataset, batch_size=batch_size, shuffle=True, - num_workers=0, + num_workers=num_workers, drop_last=True, - collate_fn=standardize_masks, + collate_fn=torch.utils.data.default_collate, pin_memory=True, ) val = torch.utils.data.DataLoader( dataset=val_dataset, batch_size=val_batch_size, - num_workers=0, + num_workers=num_workers, shuffle=False, + collate_fn=torch.utils.data.default_collate, pin_memory=True, ) data = ssl.data.DataModule(train=train, val=val) -def pos_embed(patches: torch.Tensor) -> torch.Tensor: - return ( +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + embed = ( torch.from_numpy( 
get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) ) @@ -124,6 +109,13 @@ def pos_embed(patches: torch.Tensor) -> torch.Tensor: .float() .repeat(patches.shape[0], 1, 1) ) + if with_cls: + embed = torch.cat([ + torch.zeros(embed.shape[0], 1, embed.shape[2], device=embed.device), + embed + ], dim=1) + + return embed def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: @@ -132,195 +124,139 @@ def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) return torch.gather(x, dim=1, index=mask_expanded) +def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + class MAE_Encoder(VisionTransformer): - """MAE encoder - processes only visible patches""" - def __init__(self, *args, **kwargs): - self.weight_init = kwargs.get("weight_init", "") - self.fix_init = kwargs.get("fix_init", False) - mae_in_dim = kwargs.pop("mae_in_dim") + mae_in_dim = kwargs.pop('mae_in_dim', 3 * patch_size * patch_size) super().__init__(*args, **kwargs) - + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches # do NOT add prefix here self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) - - if self.weight_init != "skip": - self.init_weights(self.weight_init) - if self.fix_init: - self.fix_init_weight() - - def patchify(self, image: torch.Tensor) -> torch.Tensor: - B, C, H, W = image.shape - P = patch_size - - patches = image.unfold(2, P, P).unfold(3, P, P) - patches = patches.permute(0, 2, 3, 1, 4, 5) - num_patch_h, num_patch_w = patches.shape[1], patches.shape[2] - patches = patches.reshape(B, num_patch_h * num_patch_w, P * P * C) - return patches - - def unpatchify(self, patches: torch.Tensor) -> torch.Tensor: - B, N, P2C = patches.shape - P = patch_size - C = patch_channel_dim // (P * P) - assert P2C == C * P * P - H = W = (num_patches * P) // (P * P) - patches = patches.reshape(B, H, W, P, P) - patches = patches.permute(0, 3, 1, 4, 2) # [B, P, H, P, W] - return patches.reshape(B, P * P * C, H, W) # [B, P * P * C, H, W] def project_patches(self, patches: torch.Tensor) -> torch.Tensor: return self.mae_patch_project(patches) - def encode_patches(self, patches: torch.Tensor) -> torch.Tensor: - x = self.blocks(patches) - x = self.norm(x) - return x -class MAE_Decoder(nn.Module): - """MAE decoder - reconstructs full image from visible patches + mask tokens""" +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_enc_dim = kwargs.pop('mae_enc_dim', 768) + super().__init__(*args, **kwargs) + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + self.decoder_embed = nn.Linear(mae_enc_dim, self.embed_dim) + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + + +def _forward_decoder(self, batch: dict, out: dict, stage) -> torch.Tensor: + decoder: MAE_Decoder = self.decoder + 
patches_visible = out["embeddings"] + inverse_shuffle = batch["ids_restore"].unsqueeze(-1).expand(-1, -1, decoder.embed_dim) + decoder_patches = decoder.decoder_embed(patches_visible) + + patches_cls, patches_visible = torch.split(decoder_patches, [1, decoder_patches.shape[1] - 1], dim=1) + batch_size = patches_visible.shape[0] + num_patches = decoder.num_patches + num_visible = patches_visible.shape[1] + + mask_tokens = decoder.mask_token.expand(batch_size, num_patches - num_visible, -1) + # combine visible patches and mask tokens then unshuffle into place + patches = torch.cat([patches_visible, mask_tokens], dim=1) + unshuffled_patches = torch.gather(patches, dim=1, index=inverse_shuffle) + patches = torch.cat([patches_cls, unshuffled_patches], dim=1) + + pe = pos_embed(patches[:, 1:, :], with_cls=True) + patches = patches + pe + for blk in decoder.blocks: + patches = blk(patches) + patches = decoder.norm(patches) + # remove cls token + patches = patches[:, 1:, :] + patches = decoder.decoder_pred(patches) + return patches + + +def forward(self, batch: dict, stage): + out = {} + encoder: MAE_Encoder = self.encoder + images = batch["image"] + image_patches = patchify(images, patch_size) + patches = encoder.project_patches(image_patches) + posemb_cls, posemb_patches = torch.split(pos_embed(patches, with_cls=True), [1, patches.shape[1]], dim=1) + patches = patches + posemb_patches + cls_tok = encoder.cls_token + posemb_cls - def __init__(self, encoder_dim=768, decoder_dim=512, decoder_depth=8, decoder_heads=16): - super().__init__() - self.decoder_dim = decoder_dim - self.decoder_embed = nn.Linear(encoder_dim, decoder_dim) - - # Decoder transformer blocks - self.decoder_blocks = nn.ModuleList([ - nn.TransformerEncoderLayer( - d_model=decoder_dim, - nhead=decoder_heads, - dim_feedforward=decoder_dim * 4, - dropout=0.0, - activation='gelu', - batch_first=True - ) - for _ in range(decoder_depth) - ]) - - self.decoder_norm = nn.LayerNorm(decoder_dim) - self.decoder_pred = nn.Linear(decoder_dim, patch_channel_dim) # Predict pixel values - - # Learnable mask token - self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) - - # Fixed positional embeddings - self.decoder_pos_embed = nn.Parameter( - torch.from_numpy( - get_2d_sincos_pos_embed(decoder_dim, int(math.sqrt(num_patches))) - ).float(), - requires_grad=False + if self.training: + indices_keep, indices_masked = batch["mask_visible"], batch["mask_masked"] + patches_visible = apply_mask(patches, indices_keep) + patches_visible = torch.cat([cls_tok, patches_visible], dim=1) + for blk in encoder.blocks: + patches_visible = blk(patches_visible) + patches_visible = encoder.norm(patches_visible) + out["embeddings"] = patches_visible + out["reconstructed_pixel_patches"] = _forward_decoder(self, batch, out, stage) + out["loss"] = self.loss_fn( + apply_mask(out["reconstructed_pixel_patches"], indices_masked), + apply_mask(image_patches, indices_masked) ) + else: + patches = torch.cat([cls_tok, patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + patches = encoder.norm(patches) + out["embeddings"] = patches + + # exclude cls token for rankme (flat_embedding) and linear probe (sum_embedding) + out["flat_embedding"] = out["embeddings"][:, 1:, :].flatten(start_dim=1) + out["sum_embedding"] = out["embeddings"][:, 1:, :].sum(dim=1) + return out - def forward(self, x_visible: torch.Tensor, mask_visible: torch.Tensor, mask_masked: torch.Tensor) -> torch.Tensor: - """ - Args: - x_visible: [B, N_visible, encoder_dim] - encoded visible 
patches - mask_visible: [B, N_visible] - indices of visible patches - mask_masked: [B, N_masked] - indices of masked patches - Returns: - x_reconstructed: [B, N_masked, patch_pixels] - reconstructed masked patches - """ - B = x_visible.shape[0] - N_visible = mask_visible.shape[1] - N_masked = mask_masked.shape[1] - - # Project encoder features to decoder dimension - x_visible = self.decoder_embed(x_visible) # [B, N_visible, decoder_dim] - - # Create mask tokens for masked positions - mask_tokens = self.mask_token.expand(B, N_masked, -1) # [B, N_masked, decoder_dim] - - # Get positional embeddings for all patches - pos_embed_all = self.decoder_pos_embed.unsqueeze(0).expand(B, -1, -1) # [B, N_total, decoder_dim] - pos_visible = apply_mask(pos_embed_all, mask_visible) # [B, N_visible, decoder_dim] - pos_masked = apply_mask(pos_embed_all, mask_masked) # [B, N_masked, decoder_dim] - - # Add positional embeddings - x_visible = x_visible + pos_visible - mask_tokens = mask_tokens + pos_masked - - # Combine visible and masked tokens - # For simplicity, concatenate them (real MAE uses more sophisticated ordering) - x_full = torch.cat([x_visible, mask_tokens], dim=1) # [B, N_visible + N_masked, decoder_dim] - - # Apply decoder transformer blocks - for block in self.decoder_blocks: - x_full = block(x_full) - - x_full = self.decoder_norm(x_full) - - # Extract predictions for masked patches only - x_masked_pred = x_full[:, N_visible:] # [B, N_masked, decoder_dim] - - # Predict pixel values - x_reconstructed = self.decoder_pred(x_masked_pred) # [B, N_masked, patch_pixels] - - return x_reconstructed encoder_kwargs = dict( + img_size=(height, width), patch_size=patch_size, embed_dim=768, - depth=12, - num_heads=12, + depth=16, + num_heads=16, qkv_bias=True, # MAE typically uses bias mae_in_dim=patch_channel_dim, ) decoder_kwargs = dict( - encoder_dim=768, - decoder_dim=512, - decoder_depth=8, - decoder_heads=16, + img_size=(height, width), + patch_size=patch_size, + mae_enc_dim=768, + embed_dim=512, + depth=8, + num_heads=16, ) -def forward(self: ssl.Module, batch, stage): - out = {} - encoder: MAE_Encoder = self.encoder - decoder: MAE_Decoder = self.decoder - mae_loss: nn.Module = self.mae_loss - - # Patchify and get all patches with positions - image_patches = encoder.patchify(batch["image"]) # [B, N, patch_pixels] - all_patches = encoder.project_patches(image_patches) # [B, N, embed_dim] - pos_embedding = pos_embed(all_patches) - all_patches = all_patches + pos_embedding - - if not self.training: - # For validation, encode all patches for downstream tasks - out["embedding"] = encoder.encode_patches(all_patches) - out["sum_embedding"] = out["embedding"].sum(dim=1) - out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) - return out - - mask_visible, mask_masked = batch["mask_visible"], batch["mask_masked"] - - # Encode only visible patches - visible_patches = apply_mask(all_patches, mask_visible) # [B, N_visible, embed_dim] - encoded_visible = encoder.encode_patches(visible_patches) # [B, N_visible, embed_dim] - - # Decode to reconstruct masked patches - reconstructed_masked = decoder(encoded_visible, mask_visible, mask_masked) # [B, N_masked, patch_pixels] - - # Get ground truth masked patches - gt_masked_patches = apply_mask(image_patches, mask_masked) # [B, N_masked, patch_pixels] - - # Compute reconstruction loss (MSE in pixel space) - out["loss"] = mae_loss(reconstructed_masked, gt_masked_patches) - - # For monitoring - out["embedding"] = encoded_visible - 
out["reconstructed"] = reconstructed_masked - out["sum_embedding"] = out["embedding"].sum(dim=1) - out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) - out["ground_truth"] = gt_masked_patches - return out module = ssl.Module( encoder=MAE_Encoder(**encoder_kwargs), decoder=MAE_Decoder(**decoder_kwargs), forward=forward, - mae_loss=F.mse_loss, # Pixel MSE loss + loss_fn=F.mse_loss, # pixel MSE loss. we make implicit assumption that norm-pix-loss is False ) # Note: Linear probe uses visible patches only during training @@ -361,12 +297,7 @@ def forward(self: ssl.Module, batch, stage): logger=wandb_logger, enable_checkpointing=False, accelerator="gpu", - devices=1, - # strategy=DDPStrategy( - # find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module - # static_graph=True, - # gradient_as_bucket_view=True, - # ) + devices=1 ) manager = ssl.Manager(trainer=trainer, module=module, data=data) From a1f84dfcf92cdb4183b31b04770584ee31d3127e Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 18 Aug 2025 06:17:48 +0000 Subject: [PATCH 20/28] mae cifar10 --- stable_ssl/tests/scripts/train_mae_cifar10.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index bc07d2dd..f9febd3e 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -124,6 +124,7 @@ def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) return torch.gather(x, dim=1, index=mask_expanded) + def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor: """ images: [B, C, H, W] -> [B, N, P*P*C] @@ -144,7 +145,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # number of patch tokens (no cls) self.patch_size = kwargs.get('patch_size', 16) - self.num_patches = self.patch_embed.num_patches # do NOT add prefix here + self.num_patches = self.patch_embed.num_patches self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) def project_patches(self, patches: torch.Tensor) -> torch.Tensor: @@ -172,6 +173,7 @@ def _forward_decoder(self, batch: dict, out: dict, stage) -> torch.Tensor: decoder: MAE_Decoder = self.decoder patches_visible = out["embeddings"] inverse_shuffle = batch["ids_restore"].unsqueeze(-1).expand(-1, -1, decoder.embed_dim) + # project encoded patches to decoder's space decoder_patches = decoder.decoder_embed(patches_visible) patches_cls, patches_visible = torch.split(decoder_patches, [1, decoder_patches.shape[1] - 1], dim=1) From c5c97a790dce8384e3d967f1d1ce5230019b805f Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 18 Aug 2025 06:32:56 +0000 Subject: [PATCH 21/28] inet1k mae --- .../tests/scripts/train_ijepa_inet1k.py | 1 + stable_ssl/tests/scripts/train_mae_cifar10.py | 6 +- stable_ssl/tests/scripts/train_mae_inet1k.py | 302 ++++++++++++++++++ 3 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 stable_ssl/tests/scripts/train_mae_inet1k.py diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index 259be8bd..4821dc3e 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -19,6 +19,7 @@ num_workers = 32 num_classes = 1000 +# TODO mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] height, width, patch_size = 256, 256, 16 diff --git 
a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index f9febd3e..23173d97 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -19,7 +19,7 @@ std = [0.229, 0.224, 0.225] height, width, patch_size = 32, 32, 4 crop_height, crop_width = 32, 32 -num_patches = (height // patch_size) * (width // patch_size) +num_patches = (crop_height // patch_size) * (crop_width // patch_size) patch_channel_dim = 3 * patch_size * patch_size mask_ratio = 0.75 num_visible_patches = int(num_patches * (1 - mask_ratio)) @@ -235,7 +235,7 @@ def forward(self, batch: dict, stage): encoder_kwargs = dict( - img_size=(height, width), + img_size=(crop_height, crop_width), patch_size=patch_size, embed_dim=768, depth=16, @@ -245,7 +245,7 @@ def forward(self, batch: dict, stage): ) decoder_kwargs = dict( - img_size=(height, width), + img_size=(crop_height, crop_width), patch_size=patch_size, mae_enc_dim=768, embed_dim=512, diff --git a/stable_ssl/tests/scripts/train_mae_inet1k.py b/stable_ssl/tests/scripts/train_mae_inet1k.py new file mode 100644 index 00000000..330aec06 --- /dev/null +++ b/stable_ssl/tests/scripts/train_mae_inet1k.py @@ -0,0 +1,302 @@ +import math + +import lightning as pl +import torch +import torch.nn.functional as F +import torchmetrics +import torchvision +from lightning.pytorch.loggers import WandbLogger +from timm.models.vision_transformer import VisionTransformer +from torch import nn +from lightning.pytorch.strategies import DDPStrategy + +import stable_ssl as ssl +from stable_ssl.data import transforms +from stable_ssl.data.utils import Dataset +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + + +train_batch_size = 128 +val_batch_size = 128 +num_workers = 32 +num_classes = 1000 + +# TODO +mean = [0.485, 0.456, 0.406] +std = [0.229, 0.224, 0.225] +height, width, patch_size = 256, 256, 16 +crop_height, crop_width = 224, 224 +num_patches = (crop_height // patch_size) * (crop_width // patch_size) +patch_channel_dim = 3 * patch_size * patch_size +mask_ratio = 0.75 +num_visible_patches = int(num_patches * (1 - mask_ratio)) + +mask_transform_kwargs = dict( + patch_size=patch_size, + mask_ratio=mask_ratio, + source="image", + target_visible="mask_visible", + target_masked="mask_masked", +) + +train_transform = transforms.Compose( + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic + transforms.RandomHorizontalFlip(), + transforms.RandomMask(**mask_transform_kwargs), + transforms.ToImage(mean=mean, std=std), +) + +val_transform = transforms.Compose( + transforms.RGB(), + transforms.Resize((height, width)), + transforms.CenterCrop((height, width)), + transforms.ToImage(mean=mean, std=std), +) + + +inet1k_train = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="train", + transform=train_transform, +) + +inet1k_val = ssl.data.HFDataset( + path="ILSVRC/imagenet-1k", + split="validation", + transform=val_transform, +) + + +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=train_batch_size, + shuffle=True, + num_workers=num_workers, + drop_last=True, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=False, + collate_fn=torch.utils.data.default_collate, + pin_memory=True, + persistent_workers=True, +) + +data = 
ssl.data.DataModule(train=train, val=val) + + +def pos_embed(patches: torch.Tensor, with_cls: bool = True) -> torch.Tensor: + embed = ( + torch.from_numpy( + get_2d_sincos_pos_embed(patches.shape[-1], int(math.sqrt(patches.shape[1]))) + ) + .to(patches.device) + .float() + .repeat(patches.shape[0], 1, 1) + ) + if with_cls: + embed = torch.cat([ + torch.zeros(embed.shape[0], 1, embed.shape[2], device=embed.device), + embed + ], dim=1) + + return embed + + +def apply_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Apply single mask to tensor""" + B, N, D = x.shape + mask_expanded = mask.unsqueeze(-1).expand(-1, -1, D) + return torch.gather(x, dim=1, index=mask_expanded) + + +def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor: + """ + images: [B, C, H, W] -> [B, N, P*P*C] + """ + B, C, H, W = images.shape + P = patch_size + assert H % P == 0 and W % P == 0 + h = H // P + w = W // P + x = images.reshape(B, C, h, P, w, P) + x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, P * P * C) + return x + + +class MAE_Encoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_in_dim = kwargs.pop('mae_in_dim', 3 * patch_size * patch_size) + super().__init__(*args, **kwargs) + # number of patch tokens (no cls) + self.patch_size = kwargs.get('patch_size', 16) + self.num_patches = self.patch_embed.num_patches + self.mae_patch_project = nn.Linear(mae_in_dim, self.embed_dim) + + def project_patches(self, patches: torch.Tensor) -> torch.Tensor: + return self.mae_patch_project(patches) + + +class MAE_Decoder(VisionTransformer): + def __init__(self, *args, **kwargs): + mae_enc_dim = kwargs.pop('mae_enc_dim', 768) + super().__init__(*args, **kwargs) + in_chans = kwargs.get('in_chans', 3) + patch_size = kwargs.get('patch_size', 16) + # tokens in the decoded grid (no cls) + self.num_patches = self.patch_embed.num_patches + self.decoder_embed = nn.Linear(mae_enc_dim, self.embed_dim) + # mask token for missing positions + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # predict pixel values per token (P^2 * C) + self.patch_size = patch_size + self.in_chans = in_chans + self.decoder_pred = nn.Linear(self.embed_dim, (patch_size ** 2) * in_chans) + + +def _forward_decoder(self, batch: dict, out: dict, stage) -> torch.Tensor: + decoder: MAE_Decoder = self.decoder + patches_visible = out["embeddings"] + inverse_shuffle = batch["ids_restore"].unsqueeze(-1).expand(-1, -1, decoder.embed_dim) + # project encoded patches to decoder's space + decoder_patches = decoder.decoder_embed(patches_visible) + + patches_cls, patches_visible = torch.split(decoder_patches, [1, decoder_patches.shape[1] - 1], dim=1) + batch_size = patches_visible.shape[0] + num_patches = decoder.num_patches + num_visible = patches_visible.shape[1] + + mask_tokens = decoder.mask_token.expand(batch_size, num_patches - num_visible, -1) + # combine visible patches and mask tokens then unshuffle into place + patches = torch.cat([patches_visible, mask_tokens], dim=1) + unshuffled_patches = torch.gather(patches, dim=1, index=inverse_shuffle) + patches = torch.cat([patches_cls, unshuffled_patches], dim=1) + + pe = pos_embed(patches[:, 1:, :], with_cls=True) + patches = patches + pe + for blk in decoder.blocks: + patches = blk(patches) + patches = decoder.norm(patches) + # remove cls token + patches = patches[:, 1:, :] + patches = decoder.decoder_pred(patches) + return patches + + +def forward(self, batch: dict, stage): + out = {} + encoder: MAE_Encoder = self.encoder + images = batch["image"] + 
image_patches = patchify(images, patch_size) + patches = encoder.project_patches(image_patches) + posemb_cls, posemb_patches = torch.split(pos_embed(patches, with_cls=True), [1, patches.shape[1]], dim=1) + patches = patches + posemb_patches + cls_tok = encoder.cls_token + posemb_cls + + if self.training: + indices_keep, indices_masked = batch["mask_visible"], batch["mask_masked"] + patches_visible = apply_mask(patches, indices_keep) + patches_visible = torch.cat([cls_tok, patches_visible], dim=1) + for blk in encoder.blocks: + patches_visible = blk(patches_visible) + patches_visible = encoder.norm(patches_visible) + out["embeddings"] = patches_visible + out["reconstructed_pixel_patches"] = _forward_decoder(self, batch, out, stage) + out["loss"] = self.loss_fn( + apply_mask(out["reconstructed_pixel_patches"], indices_masked), + apply_mask(image_patches, indices_masked) + ) + else: + patches = torch.cat([cls_tok, patches], dim=1) + for blk in encoder.blocks: + patches = blk(patches) + patches = encoder.norm(patches) + out["embeddings"] = patches + + # exclude cls token for rankme (flat_embedding) and linear probe (sum_embedding) + out["flat_embedding"] = out["embeddings"][:, 1:, :].flatten(start_dim=1) + out["sum_embedding"] = out["embeddings"][:, 1:, :].sum(dim=1) + return out + + +encoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + embed_dim=768, + depth=16, + num_heads=16, + qkv_bias=True, # MAE typically uses bias + mae_in_dim=patch_channel_dim, +) + +decoder_kwargs = dict( + img_size=(crop_height, crop_width), + patch_size=patch_size, + mae_enc_dim=768, + embed_dim=512, + depth=8, + num_heads=16, +) + + +module = ssl.Module( + encoder=MAE_Encoder(**encoder_kwargs), + decoder=MAE_Decoder(**decoder_kwargs), + forward=forward, + loss_fn=F.mse_loss, # pixel MSE loss. 
we make implicit assumption that norm-pix-loss is False +) + +# Note: Linear probe uses visible patches only during training +linear_probe = ssl.callbacks.OnlineProbe( + "linear_probe", + "sum_embedding", + "label", + probe=torch.nn.Linear(768, num_classes), + loss_fn=torch.nn.CrossEntropyLoss(), + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), + "top5": torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=5), + }, +) + +# RankMe on encoder outputs +rankme = ssl.callbacks.RankMe( + name="rankme", + target="flat_embedding", + queue_length=min(512, train_batch_size), + target_shape=(num_visible_patches, 768), +) + +# Initialize W&B logger +wandb_logger = WandbLogger( + project="mae-inet1k", + entity="slightly-more-badass", + name="mae-inet1k-run", + log_model=False, + offline=True, +) + +trainer = pl.Trainer( + max_epochs=6, + num_sanity_val_steps=0, + callbacks=[linear_probe, rankme], + precision="16-mixed", + logger=wandb_logger, + enable_checkpointing=False, + accelerator="gpu", + devices=8, + strategy=DDPStrategy( + find_unused_parameters=True, + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) +manager() \ No newline at end of file From 9d3344d1f7f6421161f7283bf86883a64d5d5045 Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 18 Aug 2025 06:35:20 +0000 Subject: [PATCH 22/28] inetk mae --- stable_ssl/tests/scripts/train_mae_inet1k.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_ssl/tests/scripts/train_mae_inet1k.py b/stable_ssl/tests/scripts/train_mae_inet1k.py index 330aec06..dbb2cfb3 100644 --- a/stable_ssl/tests/scripts/train_mae_inet1k.py +++ b/stable_ssl/tests/scripts/train_mae_inet1k.py @@ -18,7 +18,7 @@ train_batch_size = 128 val_batch_size = 128 -num_workers = 32 +num_workers = 0 num_classes = 1000 # TODO @@ -40,7 +40,7 @@ ) train_transform = transforms.Compose( - transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic + transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0)), transforms.RandomHorizontalFlip(), transforms.RandomMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), From 251daaa8e57a8c4ad6ec1932c077b5ce9c75bbff Mon Sep 17 00:00:00 2001 From: Sami Date: Tue, 19 Aug 2025 05:47:37 +0000 Subject: [PATCH 23/28] todo gradnorm --- .../tests/scripts/train_ijepa_cifar10.py | 62 +++++++++++++------ stable_ssl/tests/scripts/train_mae_cifar10.py | 3 +- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 590903e1..67c84833 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -1,5 +1,5 @@ import math - +from functools import partial import lightning as pl import torch import torch.nn.functional as F @@ -15,15 +15,29 @@ from stable_ssl.data.utils import Dataset from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +# -- optim +train_batch_size = 512 +val_batch_size = 128 +num_epochs = 125 +lr_warmup_epochs = 15 +lr = 2e-3 +max_grad_norm = 10.0 +ema = (0.97, 0.999) +ipe_scale = 1.25 + + +# -- data +num_workers = 0 +num_classes = 10 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 32, 32, 4 -crop_height, crop_width = 32, 32 # # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 -# We precompute these so the predictor can make 
sinusoidal posembeds +height, width, patch_size = 32, 32, 2 +crop_height, crop_width = 28, 28 # # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +# we precompute these so the predictor can make sinusoidal posembeds num_patches = (height // patch_size) * (width // patch_size) patch_channel_dim = 3 * patch_size * patch_size -# Based on the in1k_vith14_ep300.yaml config in the ijepa repository +# based on the in1k_vith14_ep300.yaml config in the ijepa repository mask_transform_kwargs = dict( patch_size=patch_size, context_scale=(0.85, 1.0), @@ -56,6 +70,9 @@ root="/tmp/cifar10", train=False, download=True ) +optim = partial(torch.optim.AdamW, lr=lr, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs)) + class IndexedDataset(Dataset): """Custom dataset wrapper that adds sample_idx to each sample.""" @@ -101,9 +118,9 @@ def standardize_masks(batch: list[dict]): # single views and handles masking at the model level train = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=512, + batch_size=train_batch_size, shuffle=True, # Regular shuffling, no RepeatedRandomSampler - num_workers=32, + num_workers=num_workers, drop_last=True, collate_fn=standardize_masks, ) @@ -111,8 +128,8 @@ def standardize_masks(batch: list[dict]): val_dataset = IndexedDataset(cifar_val, transform=val_transform) val = torch.utils.data.DataLoader( dataset=val_dataset, - batch_size=128, - num_workers=32, + batch_size=val_batch_size, + num_workers=num_workers, shuffle=True, ) @@ -267,30 +284,32 @@ def predict_targets( return pred +# pico vit encoder_kwargs = dict( patch_size=patch_size, - embed_dim=768, + embed_dim=64, depth=12, - num_heads=12, + num_heads=2, qkv_bias=False, ijepa_in_dim=patch_channel_dim, ) predictor_kwargs = dict( patch_size=patch_size, - embed_dim=384, + embed_dim=32, depth=6, - num_heads=6, + num_heads=2, qkv_bias=False, ijepa_encoder_dim=768, predictor_num_patches=num_patches, ) +context_encoder = IJEPA_Encoder(**encoder_kwargs) +predictor = IJEPA_Predictor(**predictor_kwargs) + def forward(self: ssl.Module, batch, stage): out = {} - target_encoder: IJEPA_Encoder = ( - self.target_encoder.teacher - ) # NOTE Would this break anything? 
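For reference, the patch selection used by the IJEPA and MAE forwards in these scripts (context masks, target masks, visible masks) is just a batched torch.gather over integer index masks. A minimal sketch with hypothetical shapes, not the exact stable-ssl helper:

    import torch

    def gather_patches(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D] patch embeddings; idx: [B, K] indices of patches to keep
        B, N, D = x.shape
        return torch.gather(x, dim=1, index=idx.unsqueeze(-1).expand(-1, -1, D))

    x = torch.randn(2, 64, 384)                  # e.g. an 8x8 patch grid, embed_dim 384 (hypothetical)
    context_idx = torch.randint(0, 64, (2, 16))  # 16 kept indices per sample
    print(gather_patches(x, context_idx).shape)  # torch.Size([2, 16, 384])
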
+ target_encoder: IJEPA_Encoder = self.target_encoder.teacher context_encoder: IJEPA_Encoder = self.context_encoder predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss @@ -325,11 +344,12 @@ def forward(self: ssl.Module, batch, stage): module = ssl.Module( - context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), - target_encoder=TeacherStudentModule(ctx), - predictor=IJEPA_Predictor(**predictor_kwargs), + context_encoder=context_encoder, + target_encoder=TeacherStudentModule(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), + predictor=predictor, forward=forward, ijepa_loss=F.smooth_l1_loss, + optim=dict(optimizer=optim, scheduler=scheduler), ) @@ -361,13 +381,15 @@ def forward(self: ssl.Module, batch, stage): offline=True, # Ensure offline mode ) + trainer = pl.Trainer( - max_epochs=6, + max_epochs=num_epochs, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first callbacks=[linear_probe, rankme], precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, ) + manager = ssl.Manager(trainer=trainer, module=module, data=data) manager() diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index 23173d97..ff21ecca 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -299,7 +299,8 @@ def forward(self, batch: dict, stage): logger=wandb_logger, enable_checkpointing=False, accelerator="gpu", - devices=1 + devices=1, + ) manager = ssl.Manager(trainer=trainer, module=module, data=data) From d5cf9d759b2e2ca49177bb1de706d429e4e4fe21 Mon Sep 17 00:00:00 2001 From: Sami Date: Tue, 19 Aug 2025 09:21:18 +0000 Subject: [PATCH 24/28] fixes and sweeps --- assets/benchmarks/cifar10/byol.yaml | 4 +- assets/benchmarks/cifar10/dino.yaml | 4 +- assets/benchmarks/cifar100/byol.yaml | 4 +- assets/benchmarks/cifar100/dino.yaml | 4 +- assets/benchmarks/imagenette/byol.yaml | 4 +- assets/benchmarks/imagenette/dino.yaml | 4 +- benchmarks/cifar10/vicreg-resnet18.py | 10 ++- stable_ssl/manager.py | 10 +++ stable_ssl/module.py | 7 +- .../tests/scripts/train_ijepa_cifar10.py | 77 +++++++++++++------ .../tests/scripts/train_ijepa_inet1k.py | 39 ++++++---- stable_ssl/tests/scripts/train_mae_cifar10.py | 35 ++++++--- stable_ssl/tests/scripts/train_mae_inet1k.py | 13 ++-- 13 files changed, 142 insertions(+), 73 deletions(-) diff --git a/assets/benchmarks/cifar10/byol.yaml b/assets/benchmarks/cifar10/byol.yaml index 7a3c227b..583c6fda 100644 --- a/assets/benchmarks/cifar10/byol.yaml +++ b/assets/benchmarks/cifar10/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/cifar10/dino.yaml b/assets/benchmarks/cifar10/dino.yaml index c2e3ec8e..d7b89869 100644 --- a/assets/benchmarks/cifar10/dino.yaml +++ b/assets/benchmarks/cifar10/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: 
stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/assets/benchmarks/cifar100/byol.yaml b/assets/benchmarks/cifar100/byol.yaml index d727d3b1..c601f8d4 100644 --- a/assets/benchmarks/cifar100/byol.yaml +++ b/assets/benchmarks/cifar100/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/cifar100/dino.yaml b/assets/benchmarks/cifar100/dino.yaml index c2e3ec8e..d7b89869 100644 --- a/assets/benchmarks/cifar100/dino.yaml +++ b/assets/benchmarks/cifar100/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: True num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/assets/benchmarks/imagenette/byol.yaml b/assets/benchmarks/imagenette/byol.yaml index ce045e8b..fd3f8d5a 100644 --- a/assets/benchmarks/imagenette/byol.yaml +++ b/assets/benchmarks/imagenette/byol.yaml @@ -13,14 +13,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: False num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 128] diff --git a/assets/benchmarks/imagenette/dino.yaml b/assets/benchmarks/imagenette/dino.yaml index 7d0afbca..724240fa 100644 --- a/assets/benchmarks/imagenette/dino.yaml +++ b/assets/benchmarks/imagenette/dino.yaml @@ -9,14 +9,14 @@ trainer: # ===== Module Parameters ===== module: backbone: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.load_backbone name: resnet50 low_resolution: False num_classes: null projector: - _target_: stable_ssl.modules.TeacherStudentModule + _target_: stable_ssl.backbone.utils.TeacherStudentWrapper student: _target_: stable_ssl.modules.MLP sizes: [2048, 2048, 2048] diff --git a/benchmarks/cifar10/vicreg-resnet18.py b/benchmarks/cifar10/vicreg-resnet18.py index 85fe5e7b..ffc7d3dc 100644 --- a/benchmarks/cifar10/vicreg-resnet18.py +++ b/benchmarks/cifar10/vicreg-resnet18.py @@ -153,10 +153,11 @@ def forward(self, batch, stage): ) wandb_logger = WandbLogger( - entity="stable-ssl", - project="cifar10-vicreg", - name="vicreg-resnet18", - log_model=False, + project="ijepa-cifar10", + entity="samibg", # Your W&B entity + name="vicreg-cifar10-run", 
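Stepping back to the `_target_` edits in the benchmark YAMLs above: Hydra instantiates nested configs recursively, so the inner `student:` block is built first and then passed as the wrapper's `student` argument. A minimal sketch of that mechanism with a hypothetical wrapper class, not the actual stable-ssl modules:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    class Wrapper:
        # Hypothetical stand-in for a teacher-student wrapper taking a `student` module.
        def __init__(self, student):
            self.student = student

    cfg = OmegaConf.create({
        "projector": {
            "_target_": "__main__.Wrapper",
            "student": {"_target_": "torch.nn.Linear", "in_features": 2048, "out_features": 128},
        }
    })

    projector = instantiate(cfg.projector)  # inner Linear is built first, then Wrapper(student=Linear)
    print(type(projector.student))          # <class 'torch.nn.Linear'>
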
+ log_model=False, # Set to True if you want to save model artifacts + offline=False, # Ensure offline mode ) trainer = pl.Trainer( @@ -165,6 +166,7 @@ def forward(self, batch, stage): callbacks=[knn_probe, linear_probe], precision="16-mixed", logger=wandb_logger, + devices=1, enable_checkpointing=False, ) diff --git a/stable_ssl/manager.py b/stable_ssl/manager.py index 8baa1c08..fb38b839 100644 --- a/stable_ssl/manager.py +++ b/stable_ssl/manager.py @@ -403,6 +403,16 @@ def __call__(self): else: ckpt_path = self.ckpt_path logging.info(f"📣📣📣 CALLING trainer.fit with {ckpt_path=} 📣📣📣") + # Enable passing in gradient_clip_* to pl.Trainer while using + # manual optimization of ssl.Module by re-assigning it here. + # https://github.com/rbalestr-lab/stable-ssl/issues/246 + if hasattr(self.trainer, "gradient_clip_val"): + self.module.gradient_clip_val = self.trainer.gradient_clip_val + self.trainer.gradient_clip_val = None + if hasattr(self.trainer, "gradient_clip_algorithm"): + self.module.gradient_clip_algorithm = self.trainer.gradient_clip_algorithm + self.trainer.gradient_clip_algorithm = None + self._trainer.fit( self.instantiated_module, datamodule=self.instantiated_data, diff --git a/stable_ssl/module.py b/stable_ssl/module.py index cdd7165f..e9777b1f 100644 --- a/stable_ssl/module.py +++ b/stable_ssl/module.py @@ -116,6 +116,9 @@ def __init__(self, *args, forward: callable, hparams: dict = None, **kwargs): self._optimizer_index_by_name = None self._optimizer_frequencies = None + def on_fit_start(self): + super().on_fit_start() + def forward(self, *args, **kwargs): raise NotImplementedError("The forward() method must be implemented.") @@ -206,8 +209,8 @@ def training_step(self, batch, batch_idx): # Clip gradients for this optimizer then step self.clip_gradients( opt, - gradient_clip_val=self.trainer.gradient_clip_val, - gradient_clip_algorithm=self.trainer.gradient_clip_algorithm, + gradient_clip_val=self.gradient_clip_val, + gradient_clip_algorithm=self.gradient_clip_algorithm, ) opt.step() opt.zero_grad(set_to_none=True) diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 67c84833..14a826db 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -10,31 +10,35 @@ from torch import nn import stable_ssl as ssl -from stable_ssl.backbone.utils import TeacherStudentModule +from stable_ssl.backbone.utils import TeacherStudentWrapper from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +# TODO Beware - some of these are commented out eventually and +# not used in the optimizer below. 
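# --- Illustrative sketch (not part of the patched code): manual-optimization clipping ---
# The stable_ssl/manager.py and module.py hunks above hand gradient_clip_val /
# gradient_clip_algorithm from the pl.Trainer over to the module, because
# ssl.Module steps its optimizers manually and Lightning does not apply
# Trainer-level clipping under manual optimization (see issue #246 referenced
# above). A minimal, self-contained illustration of that pattern follows; the
# toy ManualClipModule below is hypothetical and only mirrors the idea, it is
# not the stable-ssl implementation.
import lightning as pl
import torch


class ManualClipModule(pl.LightningModule):
    def __init__(self, gradient_clip_val=None, gradient_clip_algorithm="norm"):
        super().__init__()
        self.automatic_optimization = False  # we call backward/step ourselves
        # The Manager copies these off the Trainer before fit(), then clears
        # them on the Trainer so Lightning's manual-optimization check is happy.
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        # assume batch is a float tensor of shape [N, 8]
        opt = self.optimizers()
        loss = self.layer(batch).mean()
        opt.zero_grad()
        self.manual_backward(loss)
        if self.gradient_clip_val is not None:
            # same call the patched training_step above relies on
            self.clip_gradients(
                opt,
                gradient_clip_val=self.gradient_clip_val,
                gradient_clip_algorithm=self.gradient_clip_algorithm,
            )
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)
# --- end illustrative sketch ---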
# -- optim -train_batch_size = 512 +train_batch_size = 128 val_batch_size = 128 -num_epochs = 125 +num_epochs = 1000 lr_warmup_epochs = 15 -lr = 2e-3 -max_grad_norm = 10.0 +lr = 5 +# max_grad_norm = 5.0 +max_grad_norm = None ema = (0.97, 0.999) ipe_scale = 1.25 # -- data -num_workers = 0 +num_workers = 64 num_classes = 10 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 32, 32, 2 -crop_height, crop_width = 28, 28 # # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 +height, width, patch_size = 32, 32, 4 +crop_height, crop_width = 32, 32 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 # we precompute these so the predictor can make sinusoidal posembeds -num_patches = (height // patch_size) * (width // patch_size) +num_patches = (crop_height // patch_size) * (crop_width // patch_size) patch_channel_dim = 3 * patch_size * patch_size # based on the in1k_vith14_ep300.yaml config in the ijepa repository @@ -58,7 +62,7 @@ val_transform = transforms.Compose( transforms.RGB(), transforms.Resize((height, width)), - transforms.CenterCrop((height, width)), + transforms.CenterCrop((crop_height, crop_width)), transforms.ToImage(mean=mean, std=std), ) @@ -123,6 +127,8 @@ def standardize_masks(batch: list[dict]): num_workers=num_workers, drop_last=True, collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, ) val_dataset = IndexedDataset(cifar_val, transform=val_transform) @@ -131,6 +137,8 @@ def standardize_masks(batch: list[dict]): batch_size=val_batch_size, num_workers=num_workers, shuffle=True, + pin_memory=True, + persistent_workers=True, ) @@ -287,19 +295,19 @@ def predict_targets( # pico vit encoder_kwargs = dict( patch_size=patch_size, - embed_dim=64, + embed_dim=384, depth=12, - num_heads=2, + num_heads=6, qkv_bias=False, ijepa_in_dim=patch_channel_dim, ) predictor_kwargs = dict( patch_size=patch_size, - embed_dim=32, + embed_dim=192, depth=6, - num_heads=2, + num_heads=6, qkv_bias=False, - ijepa_encoder_dim=768, + ijepa_encoder_dim=384, predictor_num_patches=num_patches, ) @@ -334,31 +342,47 @@ def forward(self: ssl.Module, batch, stage): context_patches = apply_masks(image_patches, mask_context) context_patches = context_encoder.project_patches(context_patches) context_patches = context_patches + apply_masks(pos_embedding, mask_context) - out["context_patches"] = context_encoder.encode_patches( + context_patches = context_encoder.encode_patches( context_patches, with_layernorm=False ) + out["context_patches"] = context_patches out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) - out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) return out module = ssl.Module( context_encoder=context_encoder, - target_encoder=TeacherStudentModule(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), + target_encoder=TeacherStudentWrapper(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), predictor=predictor, forward=forward, ijepa_loss=F.smooth_l1_loss, - optim=dict(optimizer=optim, scheduler=scheduler), + optim={ + "optimizer": { + "type": "LARS", + "lr": 1e-3, + "weight_decay": 1e-6, + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + }, + "interval": "epoch", + }, + # optim=dict(optimizer=optim, scheduler=scheduler), ) +probe_optimizer = partial(torch.optim.AdamW, lr=3e-4, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +probe_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs)) + 
linear_probe = ssl.callbacks.OnlineProbe( "linear_probe", - "sum_embedding", + "flat_embedding", "label", - probe=torch.nn.Linear(768, 10), + probe=torch.nn.Linear(384 * num_patches, 10), loss_fn=torch.nn.CrossEntropyLoss(), + optimizer=probe_optimizer, + scheduler=probe_scheduler, metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(10), "top5": torchmetrics.classification.MulticlassAccuracy(10, top_k=5), @@ -368,17 +392,17 @@ def forward(self: ssl.Module, batch, stage): rankme = ssl.callbacks.RankMe( name="rankme", target="flat_embedding", - queue_length=512, # NOTE must be >= batch_size - target_shape=(num_patches, 768), + queue_length=min(512, train_batch_size), # NOTE must be >= batch_size + target_shape=(num_patches, 384), ) # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( project="ijepa-cifar10", - entity="slightly-more-badass", # Your W&B entity + entity="samibg", # Your W&B entity name="ijepa-cifar10-run", log_model=False, # Set to True if you want to save model artifacts - offline=True, # Ensure offline mode + offline=False, # Ensure offline mode ) @@ -388,7 +412,10 @@ def forward(self: ssl.Module, batch, stage): callbacks=[linear_probe, rankme], precision="16-mixed", logger=wandb_logger, + devices=1, enable_checkpointing=False, + gradient_clip_val=max_grad_norm, + gradient_clip_algorithm="norm", ) manager = ssl.Manager(trainer=trainer, module=module, data=data) diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index 4821dc3e..06facf6a 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -8,25 +8,34 @@ from torch import nn import stable_ssl as ssl -from stable_ssl.backbone.utils import TeacherStudentModule +from stable_ssl.backbone.utils import TeacherStudentWrapper from stable_ssl.data import transforms from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed from lightning.pytorch.strategies import DDPStrategy - +from functools import partial train_batch_size = 128 val_batch_size = 128 num_workers = 32 num_classes = 1000 +num_epochs = 300 +start_lr = 2e-4 +lr = 1e-3 +final_lr = 1e-6 +max_grad_norm = 5.0 +ema = (0.996, 1.0) +lr_warmup_steps = 40 + -# TODO mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 256, 256, 16 +height, width, patch_size = 256, 256, 14 crop_height, crop_width = 224, 224 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 # We precompute these so the predictor can make sinusoidal posembeds num_patches = (crop_height // patch_size) * (crop_width // patch_size) patch_channel_dim = 3 * patch_size * patch_size +encoder_embed_dim = 768 +predictor_embed_dim = 384 # Based on the in1k_vith14_ep300.yaml config in the ijepa repository mask_transform_kwargs = dict( @@ -35,9 +44,13 @@ context_aspect_ratio=(1.0, 1.0), target_scales=((0.15, 0.2),) * 4, target_aspect_ratios=((0.75, 1.5),) * 4, - min_keep=20, + min_keep=10, ) +optim = partial(torch.optim.AdamW, lr=lr, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +scheduler = partial(ssl.optim.lr_scheduler.LinearWarmupCosineAnnealingLR, warmup_start_lr=start_lr, max_steps=num_epochs, warmup_steps=lr_warmup_steps, eta_min=final_lr) + + train_transform = transforms.Compose( transforms.RGB(), @@ -264,7 +277,7 @@ def predict_targets( encoder_kwargs = dict( patch_size=patch_size, - embed_dim=768, + embed_dim=encoder_embed_dim, depth=12, num_heads=12, qkv_bias=False, @@ -272,11 +285,11 @@ def predict_targets( ) predictor_kwargs = 
dict( patch_size=patch_size, - embed_dim=384, + embed_dim=predictor_embed_dim, depth=12, num_heads=12, qkv_bias=False, - ijepa_encoder_dim=768, + ijepa_encoder_dim=encoder_embed_dim, predictor_num_patches=num_patches, ) @@ -308,18 +321,18 @@ def forward(self: ssl.Module, batch, stage): context_patches = apply_masks(image_patches, mask_context) context_patches = context_encoder.project_patches(context_patches) context_patches = context_patches + apply_masks(pos_embedding, mask_context) - out["context_patches"] = context_encoder.encode_patches( + context_patches = context_encoder.encode_patches( context_patches, with_layernorm=False ) + out["context_patches"] = context_patches out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) - out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) return out module = ssl.Module( context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), - target_encoder=TeacherStudentModule(ctx), + target_encoder=TeacherStudentWrapper(ctx), predictor=IJEPA_Predictor(**predictor_kwargs), forward=forward, ijepa_loss=F.smooth_l1_loss, @@ -330,7 +343,7 @@ def forward(self: ssl.Module, batch, stage): "linear_probe", "sum_embedding", "label", - probe=torch.nn.Linear(768, num_classes), + probe=torch.nn.Linear(encoder_embed_dim, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), @@ -342,7 +355,7 @@ def forward(self: ssl.Module, batch, stage): name="rankme", target="flat_embedding", queue_length=min(512, train_batch_size), # NOTE must be >= batch_size - target_shape=(num_patches, 768), + target_shape=(num_patches, encoder_embed_dim), ) diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index ff21ecca..a5213459 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -24,9 +24,11 @@ mask_ratio = 0.75 num_visible_patches = int(num_patches * (1 - mask_ratio)) num_classes = 10 -batch_size = 512 +batch_size = 128 val_batch_size = 128 num_workers = 16 +encoder_embed_dim = 384 +decoder_embed_dim = 256 mask_transform_kwargs = dict( patch_size=patch_size, @@ -237,7 +239,7 @@ def forward(self, batch: dict, stage): encoder_kwargs = dict( img_size=(crop_height, crop_width), patch_size=patch_size, - embed_dim=768, + embed_dim=encoder_embed_dim, depth=16, num_heads=16, qkv_bias=True, # MAE typically uses bias @@ -247,8 +249,8 @@ def forward(self, batch: dict, stage): decoder_kwargs = dict( img_size=(crop_height, crop_width), patch_size=patch_size, - mae_enc_dim=768, - embed_dim=512, + mae_enc_dim=encoder_embed_dim, + embed_dim=decoder_embed_dim, depth=8, num_heads=16, ) @@ -259,14 +261,25 @@ def forward(self, batch: dict, stage): decoder=MAE_Decoder(**decoder_kwargs), forward=forward, loss_fn=F.mse_loss, # pixel MSE loss. 
we make implicit assumption that norm-pix-loss is False + optim={ + "optimizer": { + "type": "LARS", + "lr": 1e-3, + "weight_decay": 1e-6, + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + }, + "interval": "epoch", + }, ) # Note: Linear probe uses visible patches only during training linear_probe = ssl.callbacks.OnlineProbe( "linear_probe", - "sum_embedding", + "flat_embedding", "label", - probe=torch.nn.Linear(768, num_classes), + probe=torch.nn.Linear(encoder_embed_dim * num_visible_patches, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), @@ -279,20 +292,20 @@ def forward(self, batch: dict, stage): name="rankme", target="flat_embedding", queue_length=min(512, batch_size), - target_shape=(num_visible_patches, 768), + target_shape=(num_visible_patches, encoder_embed_dim), ) # Initialize W&B logger wandb_logger = WandbLogger( - project="mae-cifar10", - entity="slightly-more-badass", + project="ijepa-cifar10", + entity="samibg", name="mae-cifar10-run", log_model=False, - offline=True, + offline=False, ) trainer = pl.Trainer( - max_epochs=6, + max_epochs=1000, num_sanity_val_steps=0, callbacks=[linear_probe, rankme], precision="16-mixed", diff --git a/stable_ssl/tests/scripts/train_mae_inet1k.py b/stable_ssl/tests/scripts/train_mae_inet1k.py index dbb2cfb3..474a66ba 100644 --- a/stable_ssl/tests/scripts/train_mae_inet1k.py +++ b/stable_ssl/tests/scripts/train_mae_inet1k.py @@ -15,7 +15,8 @@ from stable_ssl.data.utils import Dataset from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed - +encoder_embed_dim = 768 +decoder_embed_dim = 512 train_batch_size = 128 val_batch_size = 128 num_workers = 0 @@ -228,7 +229,7 @@ def forward(self, batch: dict, stage): encoder_kwargs = dict( img_size=(crop_height, crop_width), patch_size=patch_size, - embed_dim=768, + embed_dim=encoder_embed_dim, depth=16, num_heads=16, qkv_bias=True, # MAE typically uses bias @@ -238,8 +239,8 @@ def forward(self, batch: dict, stage): decoder_kwargs = dict( img_size=(crop_height, crop_width), patch_size=patch_size, - mae_enc_dim=768, - embed_dim=512, + mae_enc_dim=encoder_embed_dim, + embed_dim=decoder_embed_dim, depth=8, num_heads=16, ) @@ -257,7 +258,7 @@ def forward(self, batch: dict, stage): "linear_probe", "sum_embedding", "label", - probe=torch.nn.Linear(768, num_classes), + probe=torch.nn.Linear(encoder_embed_dim, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), @@ -270,7 +271,7 @@ def forward(self, batch: dict, stage): name="rankme", target="flat_embedding", queue_length=min(512, train_batch_size), - target_shape=(num_visible_patches, 768), + target_shape=(num_visible_patches, encoder_embed_dim), ) # Initialize W&B logger From 5ec568e12b5cc7d6c21b4e84af94f62b7eb7f2c9 Mon Sep 17 00:00:00 2001 From: Sami Date: Wed, 20 Aug 2025 08:31:56 +0000 Subject: [PATCH 25/28] fixed multiblock masking flatten bug, sweeping hparams --- stable_ssl/data/transforms.py | 4 +- .../tests/scripts/train_ijepa_cifar10.py | 101 ++++++++++++------ .../tests/scripts/train_ijepa_inet1k.py | 87 ++++++++++----- stable_ssl/tests/scripts/train_mae_cifar10.py | 82 ++++++++------ 4 files changed, 179 insertions(+), 95 deletions(-) diff --git a/stable_ssl/data/transforms.py b/stable_ssl/data/transforms.py index 0eefd7cb..52588d59 100644 --- a/stable_ssl/data/transforms.py +++ b/stable_ssl/data/transforms.py @@ -779,9 +779,9 @@ def __call__(self, x): for mask in 
target_masks: context_mask &= ~mask - x[self.target_context] = torch.nonzero(context_mask).flatten().squeeze() + x[self.target_context] = torch.nonzero(context_mask.flatten()).squeeze() x[self.target_targets] = [ - torch.nonzero(mask).flatten().squeeze() for mask in target_masks + torch.nonzero(mask.flatten()).squeeze() for mask in target_masks ] x[self.get_name(x)] = torch.tensor([scales, aspect_ratios]) return x diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 14a826db..6d9e6f9d 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -16,15 +16,11 @@ from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed -# TODO Beware - some of these are commented out eventually and -# not used in the optimizer below. -# -- optim train_batch_size = 128 val_batch_size = 128 -num_epochs = 1000 -lr_warmup_epochs = 15 -lr = 5 -# max_grad_norm = 5.0 +num_epochs = 300 +lr_warmup_epochs = 60 +# max_grad_norm = 10.0 max_grad_norm = None ema = (0.97, 0.999) ipe_scale = 1.25 @@ -48,12 +44,15 @@ context_aspect_ratio=(1.0, 1.0), target_scales=((0.15, 0.2),) * 4, target_aspect_ratios=((0.75, 1.5),) * 4, - min_keep=20, + min_keep=10, ) + train_transform = transforms.Compose( transforms.RGB(), + # transforms.CenterCrop((crop_height, crop_width)), + # transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.3, 1.0)), transforms.ContextTargetsMultiBlockMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), @@ -74,9 +73,6 @@ root="/tmp/cifar10", train=False, download=True ) -optim = partial(torch.optim.AdamW, lr=lr, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) -scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs)) - class IndexedDataset(Dataset): """Custom dataset wrapper that adds sample_idx to each sample.""" @@ -136,7 +132,7 @@ def standardize_masks(batch: list[dict]): dataset=val_dataset, batch_size=val_batch_size, num_workers=num_workers, - shuffle=True, + shuffle=False, pin_memory=True, persistent_workers=True, ) @@ -295,19 +291,19 @@ def predict_targets( # pico vit encoder_kwargs = dict( patch_size=patch_size, - embed_dim=384, + embed_dim=192, depth=12, - num_heads=6, + num_heads=3, qkv_bias=False, ijepa_in_dim=patch_channel_dim, ) predictor_kwargs = dict( patch_size=patch_size, - embed_dim=192, + embed_dim=96, depth=6, - num_heads=6, + num_heads=3, qkv_bias=False, - ijepa_encoder_dim=384, + ijepa_encoder_dim=192, predictor_num_patches=num_patches, ) @@ -326,18 +322,28 @@ def forward(self: ssl.Module, batch, stage): target_patches = target_encoder.project_patches(image_patches) pos_embedding = pos_embed(target_patches) target_patches = target_patches + pos_embedding - out["embedding"] = target_encoder.encode_patches( + out["target_embedding"] = target_encoder.encode_patches( target_patches, with_layernorm=True ) - out["sum_embedding"] = out["embedding"].sum(dim=1) - out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + # need this here (and not in if statement) so we can still compute rankme and eval properly + with torch.no_grad(): + unmasked_context_patches = context_encoder.project_patches(image_patches) + unmasked_pos_embedding = pos_embed(unmasked_context_patches) + out["context_embedding"] = context_encoder.encode_patches( + unmasked_context_patches + unmasked_pos_embedding, + with_layernorm = False + ) + out["meanpool_context_embedding"] = 
out["context_embedding"].mean(dim=1) + out["sum_context_embedding"] = out["context_embedding"].sum(dim=1) + out["flat_context_embedding"] = out["context_embedding"].reshape(out["context_embedding"].shape[0], -1) + if not self.training: return out mask_context, masks_target = batch["mask_context"], batch["masks_target"] # target encoder is applied on full patches, then masked - out["target_patches"] = apply_masks(out["embedding"], *masks_target) + out["target_patches"] = apply_masks(out["target_embedding"], *masks_target) # context encoder is applied on masked patches context_patches = apply_masks(image_patches, mask_context) context_patches = context_encoder.project_patches(context_patches) @@ -346,8 +352,10 @@ def forward(self: ssl.Module, batch, stage): context_patches, with_layernorm=False ) out["context_patches"] = context_patches + # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) + # context embedding return out @@ -356,30 +364,35 @@ def forward(self: ssl.Module, batch, stage): target_encoder=TeacherStudentWrapper(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), predictor=predictor, forward=forward, - ijepa_loss=F.smooth_l1_loss, + ijepa_loss=lambda pred, tgt: ((pred - tgt) ** 2).sum(-1).sum(-1).mean(), + # ijepa_loss=partial(F.mse_loss, reduction='mean'), optim={ "optimizer": { - "type": "LARS", - "lr": 1e-3, - "weight_decay": 1e-6, + "type": "AdamW", + "lr": 0.002, + "weight_decay": 0.0, }, "scheduler": { "type": "LinearWarmupCosineAnnealing", + "peak_step": lr_warmup_epochs * len(train), + "total_steps": num_epochs * len(train), }, - "interval": "epoch", + "interval": "step", }, - # optim=dict(optimizer=optim, scheduler=scheduler), ) -probe_optimizer = partial(torch.optim.AdamW, lr=3e-4, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) -probe_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs)) +probe_optimizer = partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +probe_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs * len(train))) linear_probe = ssl.callbacks.OnlineProbe( "linear_probe", - "flat_embedding", + "meanpool_context_embedding", "label", - probe=torch.nn.Linear(384 * num_patches, 10), + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(192, affine=False), + torch.nn.Linear(192, 10) + ), loss_fn=torch.nn.CrossEntropyLoss(), optimizer=probe_optimizer, scheduler=probe_scheduler, @@ -391,9 +404,9 @@ def forward(self: ssl.Module, batch, stage): rankme = ssl.callbacks.RankMe( name="rankme", - target="flat_embedding", + target="flat_context_embedding", queue_length=min(512, train_batch_size), # NOTE must be >= batch_size - target_shape=(num_patches, 384), + target_shape=(192 * num_patches), ) # Initialize W&B logger with explicit settings @@ -405,11 +418,31 @@ def forward(self: ssl.Module, batch, stage): offline=False, # Ensure offline mode ) +class PerModuleGradLogger(pl.Callback): + def __init__(self, modules=("predictor", "context_encoder", "target_encoder"), norm_type=2): + self.modules = modules + self.norm_type = norm_type + + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + device = pl_module.device + if trainer.global_step % 1000 == 0: + for mod in self.modules: + group = [(n, p) 
for n, p in pl_module.named_parameters() if n.startswith(f"{mod}.")] + grads = [p.grad for _, p in group if p.grad is not None] + if not grads: + pl_module.log(f"grad/{mod}_norm", torch.tensor(0.0, device=device), on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", torch.tensor(0, device=device), on_step=True, logger=True, sync_dist=True) + continue + norms = torch.stack([g.detach().data.float().norm(self.norm_type).to(device) for g in grads]) + total_norm = torch.norm(norms, self.norm_type) + nonzero = (norms > 0).sum() + pl_module.log(f"grad/{mod}_norm", total_norm, on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", nonzero, on_step=True, logger=True, sync_dist=True) trainer = pl.Trainer( max_epochs=num_epochs, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first - callbacks=[linear_probe, rankme], + callbacks=[linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder"))], precision="16-mixed", logger=wandb_logger, devices=1, diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index 06facf6a..e0385f60 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -6,25 +6,26 @@ from lightning.pytorch.loggers import WandbLogger from timm.models.vision_transformer import VisionTransformer from torch import nn +from functools import partial import stable_ssl as ssl from stable_ssl.backbone.utils import TeacherStudentWrapper from stable_ssl.data import transforms from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed from lightning.pytorch.strategies import DDPStrategy -from functools import partial train_batch_size = 128 val_batch_size = 128 num_workers = 32 -num_classes = 1000 +num_classes = 10 num_epochs = 300 -start_lr = 2e-4 -lr = 1e-3 -final_lr = 1e-6 -max_grad_norm = 5.0 +lr_warmup_epochs = 60 +start_lr = 0.0002 +lr = 0.001 +final_lr = 1.0e-6 +# max_grad_norm = 5.0 +max_grad_norm = None ema = (0.996, 1.0) -lr_warmup_steps = 40 mean = [0.485, 0.456, 0.406] @@ -47,10 +48,6 @@ min_keep=10, ) -optim = partial(torch.optim.AdamW, lr=lr, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) -scheduler = partial(ssl.optim.lr_scheduler.LinearWarmupCosineAnnealingLR, warmup_start_lr=start_lr, max_steps=num_epochs, warmup_steps=lr_warmup_steps, eta_min=final_lr) - - train_transform = transforms.Compose( transforms.RGB(), @@ -63,19 +60,21 @@ val_transform = transforms.Compose( transforms.RGB(), transforms.Resize((height, width)), - transforms.CenterCrop((height, width)), + transforms.CenterCrop((crop_height, crop_width)), transforms.ToImage(mean=mean, std=std), ) inet1k_train = ssl.data.HFDataset( - path="ILSVRC/imagenet-1k", + path="frgfm/imagenette", + name="320px", split="train", transform=train_transform, ) inet1k_val = ssl.data.HFDataset( - path="ILSVRC/imagenet-1k", + path="frgfm/imagenette", + name="320px", split="validation", transform=val_transform, ) @@ -305,18 +304,28 @@ def forward(self: ssl.Module, batch, stage): target_patches = target_encoder.project_patches(image_patches) pos_embedding = pos_embed(target_patches) target_patches = target_patches + pos_embedding - out["embedding"] = target_encoder.encode_patches( + out["target_embedding"] = target_encoder.encode_patches( target_patches, with_layernorm=True ) - out["sum_embedding"] = out["embedding"].sum(dim=1) - out["flat_embedding"] = out["embedding"].reshape(out["embedding"].shape[0], -1) + # 
need this here (and not in if statement) so we can still compute rankme and eval properly + with torch.no_grad(): + unmasked_context_patches = context_encoder.project_patches(image_patches) + unmasked_pos_embedding = pos_embed(unmasked_context_patches) + out["context_embedding"] = context_encoder.encode_patches( + unmasked_context_patches + unmasked_pos_embedding, + with_layernorm = False + ) + out["meanpool_context_embedding"] = out["context_embedding"].mean(dim=1) + out["sum_context_embedding"] = out["context_embedding"].sum(dim=1) + out["flat_context_embedding"] = out["context_embedding"].reshape(out["context_embedding"].shape[0], -1) + if not self.training: return out mask_context, masks_target = batch["mask_context"], batch["masks_target"] # target encoder is applied on full patches, then masked - out["target_patches"] = apply_masks(out["embedding"], *masks_target) + out["target_patches"] = apply_masks(out["target_embedding"], *masks_target) # context encoder is applied on masked patches context_patches = apply_masks(image_patches, mask_context) context_patches = context_encoder.project_patches(context_patches) @@ -325,8 +334,10 @@ def forward(self: ssl.Module, batch, stage): context_patches, with_layernorm=False ) out["context_patches"] = context_patches + # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) + # context embedding return out @@ -335,15 +346,32 @@ def forward(self: ssl.Module, batch, stage): target_encoder=TeacherStudentWrapper(ctx), predictor=IJEPA_Predictor(**predictor_kwargs), forward=forward, - ijepa_loss=F.smooth_l1_loss, + ijepa_loss=partial(F.mse_loss, reduction='mean'), + optim={ + "optimizer": { + "type": "AdamW", + "lr": lr, + "weight_decay": 0.04, + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": lr_warmup_epochs * len(train), + "total_steps": num_epochs * len(train), + "end_lr": final_lr, + }, + "interval": "step", + } ) linear_probe = ssl.callbacks.OnlineProbe( "linear_probe", - "sum_embedding", + "meanpool_context_embedding", "label", - probe=torch.nn.Linear(encoder_embed_dim, num_classes), + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder_embed_dim, affine=False), + torch.nn.Linear(encoder_embed_dim, num_classes) + ), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), @@ -353,30 +381,31 @@ def forward(self: ssl.Module, batch, stage): rankme = ssl.callbacks.RankMe( name="rankme", - target="flat_embedding", + target="flat_context_embedding", queue_length=min(512, train_batch_size), # NOTE must be >= batch_size - target_shape=(num_patches, encoder_embed_dim), + target_shape=(encoder_embed_dim * num_patches), ) # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( - project="ijepa-inet1k", - entity="slightly-more-badass", # Your W&B entity - name="ijepa-inet1k-run", + project="ijepa-cifar10", + entity="samibg", # Your W&B entity + name="ijepa-inette-run", log_model=False, # Set to True if you want to save model artifacts - offline=True, # Ensure offline mode + offline=False, # Ensure offline mode ) trainer = pl.Trainer( - max_epochs=300, + max_epochs=num_epochs, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first callbacks=[linear_probe, rankme], precision="16-mixed", logger=wandb_logger, 
enable_checkpointing=False, accelerator="gpu", - devices=8, + devices=1, + gradient_clip_val=max_grad_norm, strategy=DDPStrategy( find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module static_graph=True, diff --git a/stable_ssl/tests/scripts/train_mae_cifar10.py b/stable_ssl/tests/scripts/train_mae_cifar10.py index a5213459..fe2fed28 100644 --- a/stable_ssl/tests/scripts/train_mae_cifar10.py +++ b/stable_ssl/tests/scripts/train_mae_cifar10.py @@ -15,20 +15,36 @@ from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +# Dataset configuration +num_classes = 10 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 32, 32, 4 + +# Image and patch configuration +height, width = 32, 32 crop_height, crop_width = 32, 32 +patch_size = 2 num_patches = (crop_height // patch_size) * (crop_width // patch_size) patch_channel_dim = 3 * patch_size * patch_size + +# Masking configuration mask_ratio = 0.75 num_visible_patches = int(num_patches * (1 - mask_ratio)) -num_classes = 10 + +# Training configuration batch_size = 128 val_batch_size = 128 +num_epochs = 2000 num_workers = 16 -encoder_embed_dim = 384 -decoder_embed_dim = 256 + +# Optimization configuration +lr = 1.5e-4 +warmup_epochs = 400 +weight_decay = 0.05 + +# Model configuration +encoder_embed_dim = 192 +decoder_embed_dim = 192 mask_transform_kwargs = dict( patch_size=patch_size, @@ -39,7 +55,7 @@ ) train_transform = transforms.Compose( - transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic + # transforms.RandomResizedCrop((crop_height, crop_width), scale=(0.2, 1.0), interpolation=3), # 3 is bicubic transforms.RandomHorizontalFlip(), transforms.RandomMask(**mask_transform_kwargs), transforms.ToImage(mean=mean, std=std), @@ -47,8 +63,8 @@ val_transform = transforms.Compose( transforms.RGB(), - transforms.Resize((height, width)), - transforms.CenterCrop((height, width)), + # transforms.Resize((height, width)), + # transforms.CenterCrop((height, width)), transforms.ToImage(mean=mean, std=std), ) @@ -88,6 +104,7 @@ def __len__(self): drop_last=True, collate_fn=torch.utils.data.default_collate, pin_memory=True, + persistent_workers=True, ) val = torch.utils.data.DataLoader( @@ -97,6 +114,7 @@ def __len__(self): shuffle=False, collate_fn=torch.utils.data.default_collate, pin_memory=True, + persistent_workers=True, ) data = ssl.data.DataModule(train=train, val=val) @@ -230,9 +248,6 @@ def forward(self, batch: dict, stage): patches = encoder.norm(patches) out["embeddings"] = patches - # exclude cls token for rankme (flat_embedding) and linear probe (sum_embedding) - out["flat_embedding"] = out["embeddings"][:, 1:, :].flatten(start_dim=1) - out["sum_embedding"] = out["embeddings"][:, 1:, :].sum(dim=1) return out @@ -240,8 +255,8 @@ def forward(self, batch: dict, stage): img_size=(crop_height, crop_width), patch_size=patch_size, embed_dim=encoder_embed_dim, - depth=16, - num_heads=16, + depth=12, + num_heads=3, qkv_bias=True, # MAE typically uses bias mae_in_dim=patch_channel_dim, ) @@ -251,8 +266,8 @@ def forward(self, batch: dict, stage): patch_size=patch_size, mae_enc_dim=encoder_embed_dim, embed_dim=decoder_embed_dim, - depth=8, - num_heads=16, + depth=4, + num_heads=3, ) @@ -263,23 +278,38 @@ def forward(self, batch: dict, stage): loss_fn=F.mse_loss, # pixel MSE loss. 
we make implicit assumption that norm-pix-loss is False optim={ "optimizer": { - "type": "LARS", - "lr": 1e-3, - "weight_decay": 1e-6, + "type": "AdamW", + "lr": lr * (batch_size / 256), + "weight_decay": weight_decay, + "betas": (0.9, 0.999), }, "scheduler": { "type": "LinearWarmupCosineAnnealing", + "peak_step": warmup_epochs * len(train), + "total_steps": num_epochs * len(train), }, - "interval": "epoch", + "interval": "step", }, ) + +class MAE_Classifier(torch.nn.Module): + def __init__(self, patch_dim: int, num_classes: int): + super().__init__() + # classifies from the cls token + self.classifier = torch.nn.Linear(patch_dim, num_classes) + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + cls_token = embedding[:, 0, :] + return self.classifier(cls_token) + + # Note: Linear probe uses visible patches only during training linear_probe = ssl.callbacks.OnlineProbe( "linear_probe", - "flat_embedding", + "embeddings", "label", - probe=torch.nn.Linear(encoder_embed_dim * num_visible_patches, num_classes), + probe=MAE_Classifier(encoder_embed_dim, num_classes), loss_fn=torch.nn.CrossEntropyLoss(), metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(num_classes), @@ -287,13 +317,6 @@ def forward(self, batch: dict, stage): }, ) -# RankMe on encoder outputs -rankme = ssl.callbacks.RankMe( - name="rankme", - target="flat_embedding", - queue_length=min(512, batch_size), - target_shape=(num_visible_patches, encoder_embed_dim), -) # Initialize W&B logger wandb_logger = WandbLogger( @@ -305,15 +328,14 @@ def forward(self, batch: dict, stage): ) trainer = pl.Trainer( - max_epochs=1000, + max_epochs=num_epochs, num_sanity_val_steps=0, - callbacks=[linear_probe, rankme], + callbacks=[linear_probe], precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, accelerator="gpu", devices=1, - ) manager = ssl.Manager(trainer=trainer, module=module, data=data) From 07241a9c14c1cbebfbbdf3e614b5597d96d59a46 Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 21 Aug 2025 07:42:40 +0000 Subject: [PATCH 26/28] a lot of debugging, new script --- benchmarks/imagenet100/ijepa-vith.py | 358 ++++++++++++++++++ .../tests/scripts/train_ijepa_cifar10.py | 79 ++-- .../tests/scripts/train_ijepa_inet1k.py | 80 ++-- 3 files changed, 457 insertions(+), 60 deletions(-) create mode 100644 benchmarks/imagenet100/ijepa-vith.py diff --git a/benchmarks/imagenet100/ijepa-vith.py b/benchmarks/imagenet100/ijepa-vith.py new file mode 100644 index 00000000..f676df52 --- /dev/null +++ b/benchmarks/imagenet100/ijepa-vith.py @@ -0,0 +1,358 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Block +import lightning.pytorch.loggers as pl_loggers +import torchmetrics +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback +from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.data import transforms + + +import stable_ssl as ssl +import lightning as pl +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +NUM_DEVICES = 8 +BATCH_SIZE = 256 +VAL_BATCH_SIZE = (16384 // NUM_DEVICES) +# VAL_LR_SWEEPS = [0.01, 0.05, 0.001] +LR_MULTIPLIER = (BATCH_SIZE * NUM_DEVICES) / (128 * 16) + + +def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, 
M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + +def repeat_interleave_batch(x: torch.Tensor, B, repeat): + N = x.shape[0] // B + x = x[:N*B] + return x.reshape(N, B, *x.shape[1:]) \ + .repeat_interleave(repeat, dim=1) \ + .reshape(N * B * repeat, *x.shape[1:]) + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def fix_init_weight(blocks): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + +class IJEPA_ViT_Encoder(nn.Module): + def __init__( + self, + img_size=224, patch_size=14, embed_dim=768, + depth=12, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.norm_layer = nn.LayerNorm + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, + in_chans=3, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + # -- + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**.5), + cls_token=False) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + # -- + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, + mlp_ratio=4.0, qkv_bias=True, attn_drop=0.0, drop_path=0.0, + norm_layer=self.norm_layer + ) + for _ in range(depth) + ]) + self.norm = self.norm_layer(embed_dim) + self.apply(init_weights) + fix_init_weight(self.blocks) + + def forward(self, x, masks: None | list = None): + x = self.patch_embed(x) + x = x + self.pos_embed + + if masks is not None: + x = apply_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + return self.norm(x) + +class IJEPA_ViT_Predictor(nn.Module): + def __init__( + self, + num_patches, embed_dim=768, predictor_embed_dim=384, + depth=6, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + # -- + self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + int(num_patches**.5), + cls_token=False) + self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + # -- + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, num_heads=num_heads, + mlp_ratio=4, qkv_bias=True, + attn_drop=0.0, drop_path=0.0, + norm_layer=nn.LayerNorm + ) + for _ in range(depth) + ]) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + # ------ + trunc_normal_(self.mask_token, std=0.02) + self.apply(init_weights) + fix_init_weight(self.predictor_blocks) + + def forward(self, context_patches, 
masks_context: list, masks_target: list): + assert len(masks_context) == 1 + B = len(context_patches) // len(masks_context) + x = self.predictor_embed(context_patches) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(x_pos_embed, masks_context) + + _, N_ctxt, D = x.shape + + # -- concat mask tokens to x + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_target) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_context)) + # -- + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + # -- + pred_tokens += pos_embs + x = x.repeat(len(masks_target), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # -- fwd prop + for blk in self.predictor_blocks: + x = blk(x) + x = self.predictor_norm(x) + + # -- return preds for mask tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + +def forward_ijepa(self: ssl.Module, batch: dict, stage: str) -> dict: + out = {} + + with torch.no_grad(): + # -- keys for rankme and linear probe + out['context_embeddings'] = self.context_encoder(batch['image'], masks = None) + out['meanpool'] = out['context_embeddings'].mean(dim=1) + out['flat'] = out['context_embeddings'].reshape(out['context_embeddings'].shape[0], -1) + + if not self.training: + return out + + target_patches = self.target_encoder(batch['image']) + target_patches = F.layer_norm(target_patches, (target_patches.size(-1),)) + target_patches = apply_masks(target_patches, batch['masks_target']) + out['target_embeddings'] = target_patches + + out['context_embeddings'] = self.context_encoder(batch['image'], [batch['mask_context']]) + out['predictions'] = self.predictor( + out['context_embeddings'], + [batch['mask_context']], + batch['masks_target'] + ) + out['loss'] = self.ijepa_loss(out['predictions'], out['target_embeddings']) + return out + + +inet1k_train = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="train", + transform=transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((224, 224), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask( + patch_size=14, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, + ), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +inet1k_val = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="validation", + transform=transforms.Compose( + transforms.RGB(), + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +# collate function that makes each mask have the same number of indices, so they can be batched +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + 
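    # Note on the truncation above: each sample's mask_context / masks_target
    # entries are 1-D index tensors whose lengths differ from sample to sample
    # (min_keep only lower-bounds them), so default_collate cannot stack them
    # directly. Slicing every mask down to the shortest length in the batch
    # makes them rectangular; with made-up lengths, context masks of 181, 179
    # and 183 indices give min_keep_enc = 179 and a collated [3, 179] tensor.
    # The four target masks are likewise cut to the shortest block length
    # found anywhere in the batch before being collated.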
batch["masks_target"] = collated_masks_target + return batch + +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=16, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=VAL_BATCH_SIZE, + num_workers=16, + shuffle=False, +) + +encoder = IJEPA_ViT_Encoder() +predictor = IJEPA_ViT_Predictor(num_patches=encoder.patch_embed.num_patches) +target_encoder = TeacherStudentWrapper(encoder, base_ema_coefficient=0.996, final_ema_coefficient=1.0) + +ema_callback = TeacherStudentCallback(update_frequency=1, update_after_backward=True) +rankme = ssl.callbacks.RankMe( + name="rankme", target="flat", + queue_length=min(512, BATCH_SIZE), + target_shape=(encoder.embed_dim * encoder.patch_embed.num_patches) +) +sweep_lr = 0.005 +linear_probe = ssl.callbacks.OnlineProbe( + name=f'linear_probe_lr{sweep_lr:.0e}', input='meanpool', target='label', + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder.embed_dim, affine=False), + torch.nn.Linear(encoder.embed_dim, 100) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + optimizer={ + "type": "LARS", + "lr": sweep_lr, + "weight_decay": 1e-6, + }, + scheduler={ + "type": "StepLR", + "step_size": 15, + "gamma": 0.1, + }, + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(100), + "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5), + } + ) + +module = ssl.Module( + context_encoder=encoder, + target_encoder=target_encoder, + predictor=predictor, + forward=forward_ijepa, + ijepa_loss=F.smooth_l1_loss, + optim={ + "optimizer": { + "type": "AdamW", + "lr": 0.001 * LR_MULTIPLIER, + "weight_decay": 0.04, # TODO Scheduler to 0.4 + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": 40 * len(train), + "start_factor": 1 / 5, + "total_steps": 300 * len(train), + "end_lr": 1.0e-6 * LR_MULTIPLIER, + }, + "interval": "step", + } +) + +trainer = pl.Trainer( + max_epochs=300, num_sanity_val_steps=0, + callbacks=[ + linear_probe, rankme, ema_callback + ], + precision='16-mixed', + logger=pl_loggers.WandbLogger( + project="ijepa-cifar10", entity="samibg", name="new-ijepa-inet100", + log_model=False, offline=False, + ), + enable_checkpointing=False, + accelerator="gpu", devices=NUM_DEVICES, gradient_clip_val=None, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +data = ssl.data.DataModule(train=train, val=val) + +manager = ssl.Manager(trainer=trainer, module=module, data=data) + +if __name__ == "__main__": + manager() \ No newline at end of file diff --git a/stable_ssl/tests/scripts/train_ijepa_cifar10.py b/stable_ssl/tests/scripts/train_ijepa_cifar10.py index 6d9e6f9d..2a0f22eb 100644 --- a/stable_ssl/tests/scripts/train_ijepa_cifar10.py +++ b/stable_ssl/tests/scripts/train_ijepa_cifar10.py @@ -14,24 +14,27 @@ from stable_ssl.data import transforms from stable_ssl.data.utils import Dataset from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback -train_batch_size = 128 +train_batch_size = 512 val_batch_size = 128 -num_epochs = 300 -lr_warmup_epochs = 60 +num_epochs = 100 +lr_warmup_epochs = 15 # max_grad_norm = 10.0 max_grad_norm = None ema = (0.97, 0.999) ipe_scale = 1.25 - +encoder_embed_dim = 64 +predictor_embed_dim = 
32 +lr = 0.02 # -- data num_workers = 64 num_classes = 10 mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] -height, width, patch_size = 32, 32, 4 +height, width, patch_size = 32, 32, 2 crop_height, crop_width = 32, 32 # CIFAR-10 is 32x32, but on INET, IJEPA uses 224 # we precompute these so the predictor can make sinusoidal posembeds num_patches = (crop_height // patch_size) * (crop_width // patch_size) @@ -44,7 +47,7 @@ context_aspect_ratio=(1.0, 1.0), target_scales=((0.15, 0.2),) * 4, target_aspect_ratios=((0.75, 1.5),) * 4, - min_keep=10, + min_keep=4, ) @@ -172,7 +175,7 @@ class IJEPA_Encoder(VisionTransformer): def __init__(self, *args, **kwargs): self.weight_init = kwargs.get("weight_init", "") - self.fix_init = kwargs.get("fix_init", False) + self.fix_init = kwargs.get("fix_init", True) ijepa_in_dim = kwargs.pop("ijepa_in_dim") super().__init__(*args, **kwargs) @@ -236,7 +239,7 @@ class IJEPA_Predictor(VisionTransformer): def __init__(self, *args, **kwargs): self.weight_init = kwargs.get("weight_init", "") - self.fix_init = kwargs.get("fix_init", False) + self.fix_init = kwargs.get("fix_init", True) self.predictor_num_patches = kwargs.pop("predictor_num_patches") self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") self.predictor_pos_embed = pos_embed( @@ -256,14 +259,17 @@ def __init__(self, *args, **kwargs): if self.fix_init: self.fix_init_weight() + def project_context(self, context_patches: torch.Tensor) -> torch.Tensor: + return self.predictor_inproj(context_patches) + def predict_targets( self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] ) -> torch.Tensor: - # NOTE context_patches already positionally embedded B, *_ = context_patches.shape M = len(masks_target) - ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] + # NOTE: These are already projected -> posembedded + ctx: torch.Tensor = context_patches # target position embeddings (stacked per mask): [B*M, K_tgt, D] pos_all = self.predictor_pos_embed.expand(B, -1, -1) @@ -291,19 +297,19 @@ def predict_targets( # pico vit encoder_kwargs = dict( patch_size=patch_size, - embed_dim=192, + embed_dim=encoder_embed_dim, depth=12, - num_heads=3, - qkv_bias=False, + num_heads=2, + qkv_bias=True, ijepa_in_dim=patch_channel_dim, ) predictor_kwargs = dict( patch_size=patch_size, - embed_dim=96, + embed_dim=predictor_embed_dim, depth=6, - num_heads=3, - qkv_bias=False, - ijepa_encoder_dim=192, + num_heads=2, + qkv_bias=True, + ijepa_encoder_dim=encoder_embed_dim, predictor_num_patches=num_patches, ) @@ -317,17 +323,14 @@ def forward(self: ssl.Module, batch, stage): context_encoder: IJEPA_Encoder = self.context_encoder predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss - - image_patches = target_encoder.patchify(batch["image"]) - target_patches = target_encoder.project_patches(image_patches) - pos_embedding = pos_embed(target_patches) - target_patches = target_patches + pos_embedding - out["target_embedding"] = target_encoder.encode_patches( - target_patches, with_layernorm=True - ) - - # need this here (and not in if statement) so we can still compute rankme and eval properly with torch.no_grad(): + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["target_embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) unmasked_context_patches = 
context_encoder.project_patches(image_patches) unmasked_pos_embedding = pos_embed(unmasked_context_patches) out["context_embedding"] = context_encoder.encode_patches( @@ -353,7 +356,10 @@ def forward(self: ssl.Module, batch, stage): ) out["context_patches"] = context_patches # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output - out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) + out["predicted_patches"] = predictor.predict_targets( + predictor.project_context(context_patches) + apply_masks(predictor.predictor_pos_embed.repeat(context_patches.shape[0],1,1), mask_context), + masks_target + ) out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) # context embedding return out @@ -364,12 +370,12 @@ def forward(self: ssl.Module, batch, stage): target_encoder=TeacherStudentWrapper(context_encoder, base_ema_coefficient=ema[0], final_ema_coefficient=ema[1]), predictor=predictor, forward=forward, - ijepa_loss=lambda pred, tgt: ((pred - tgt) ** 2).sum(-1).sum(-1).mean(), + ijepa_loss=F.smooth_l1_loss, # ijepa_loss=partial(F.mse_loss, reduction='mean'), optim={ "optimizer": { "type": "AdamW", - "lr": 0.002, + "lr": lr, "weight_decay": 0.0, }, "scheduler": { @@ -382,7 +388,7 @@ def forward(self: ssl.Module, batch, stage): ) -probe_optimizer = partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) +probe_optimizer = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8) probe_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=int(ipe_scale * num_epochs * len(train))) linear_probe = ssl.callbacks.OnlineProbe( @@ -390,8 +396,8 @@ def forward(self: ssl.Module, batch, stage): "meanpool_context_embedding", "label", probe=torch.nn.Sequential( - torch.nn.BatchNorm1d(192, affine=False), - torch.nn.Linear(192, 10) + torch.nn.BatchNorm1d(encoder_embed_dim, affine=False), + torch.nn.Linear(encoder_embed_dim, 10) ), loss_fn=torch.nn.CrossEntropyLoss(), optimizer=probe_optimizer, @@ -406,7 +412,7 @@ def forward(self: ssl.Module, batch, stage): name="rankme", target="flat_context_embedding", queue_length=min(512, train_batch_size), # NOTE must be >= batch_size - target_shape=(192 * num_patches), + target_shape=(encoder_embed_dim * num_patches), ) # Initialize W&B logger with explicit settings @@ -442,7 +448,10 @@ def on_before_optimizer_step(self, trainer, pl_module, optimizer): trainer = pl.Trainer( max_epochs=num_epochs, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first - callbacks=[linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder"))], + callbacks=[ + linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder")), + TeacherStudentCallback(update_frequency=100), + ], precision="16-mixed", logger=wandb_logger, devices=1, diff --git a/stable_ssl/tests/scripts/train_ijepa_inet1k.py b/stable_ssl/tests/scripts/train_ijepa_inet1k.py index e0385f60..75dad612 100644 --- a/stable_ssl/tests/scripts/train_ijepa_inet1k.py +++ b/stable_ssl/tests/scripts/train_ijepa_inet1k.py @@ -13,11 +13,12 @@ from stable_ssl.data import transforms from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback train_batch_size = 128 val_batch_size = 128 num_workers = 32 -num_classes = 10 +num_classes = 1000 num_epochs = 
300 lr_warmup_epochs = 60 start_lr = 0.0002 @@ -66,15 +67,13 @@ inet1k_train = ssl.data.HFDataset( - path="frgfm/imagenette", - name="320px", + path="ILSVRC/imagenet-1k", split="train", transform=train_transform, ) inet1k_val = ssl.data.HFDataset( - path="frgfm/imagenette", - name="320px", + path="ILSVRC/imagenet-1k", split="validation", transform=val_transform, ) @@ -158,7 +157,7 @@ class IJEPA_Encoder(VisionTransformer): def __init__(self, *args, **kwargs): self.weight_init = kwargs.get("weight_init", "") - self.fix_init = kwargs.get("fix_init", False) + self.fix_init = kwargs.get("fix_init", True) ijepa_in_dim = kwargs.pop("ijepa_in_dim") super().__init__(*args, **kwargs) @@ -222,7 +221,7 @@ class IJEPA_Predictor(VisionTransformer): def __init__(self, *args, **kwargs): self.weight_init = kwargs.get("weight_init", "") - self.fix_init = kwargs.get("fix_init", False) + self.fix_init = kwargs.get("fix_init", True) self.predictor_num_patches = kwargs.pop("predictor_num_patches") self.ijepa_encoder_dim = kwargs.pop("ijepa_encoder_dim") self.predictor_pos_embed = pos_embed( @@ -242,6 +241,9 @@ def __init__(self, *args, **kwargs): if self.fix_init: self.fix_init_weight() + def project_context(self, context_patches: torch.Tensor) -> torch.Tensor: + return self.predictor_inproj(context_patches) + def predict_targets( self, context_patches: torch.Tensor, masks_target: list[torch.Tensor] ) -> torch.Tensor: @@ -249,7 +251,8 @@ def predict_targets( B, *_ = context_patches.shape M = len(masks_target) - ctx: torch.Tensor = self.predictor_inproj(context_patches) # [B, N_ctx, D] + # NOTE: These are already projected -> posembedded + ctx: torch.Tensor = context_patches # target position embeddings (stacked per mask): [B*M, K_tgt, D] pos_all = self.predictor_pos_embed.expand(B, -1, -1) @@ -279,7 +282,7 @@ def predict_targets( embed_dim=encoder_embed_dim, depth=12, num_heads=12, - qkv_bias=False, + qkv_bias=True, ijepa_in_dim=patch_channel_dim, ) predictor_kwargs = dict( @@ -287,7 +290,7 @@ def predict_targets( embed_dim=predictor_embed_dim, depth=12, num_heads=12, - qkv_bias=False, + qkv_bias=True, ijepa_encoder_dim=encoder_embed_dim, predictor_num_patches=num_patches, ) @@ -299,17 +302,14 @@ def forward(self: ssl.Module, batch, stage): context_encoder: IJEPA_Encoder = self.context_encoder predictor: IJEPA_Predictor = self.predictor ijepa_loss: nn.Module = self.ijepa_loss - - image_patches = target_encoder.patchify(batch["image"]) - target_patches = target_encoder.project_patches(image_patches) - pos_embedding = pos_embed(target_patches) - target_patches = target_patches + pos_embedding - out["target_embedding"] = target_encoder.encode_patches( - target_patches, with_layernorm=True - ) - - # need this here (and not in if statement) so we can still compute rankme and eval properly with torch.no_grad(): + image_patches = target_encoder.patchify(batch["image"]) + target_patches = target_encoder.project_patches(image_patches) + pos_embedding = pos_embed(target_patches) + target_patches = target_patches + pos_embedding + out["target_embedding"] = target_encoder.encode_patches( + target_patches, with_layernorm=True + ) unmasked_context_patches = context_encoder.project_patches(image_patches) unmasked_pos_embedding = pos_embed(unmasked_context_patches) out["context_embedding"] = context_encoder.encode_patches( @@ -334,19 +334,26 @@ def forward(self: ssl.Module, batch, stage): context_patches, with_layernorm=False ) out["context_patches"] = context_patches + # TODO Re-write this whole thing with the patchembeds 
inside the models themselves and using timms patchembed/block + # but making ur own vits. # use context_patches.shape[0] because applying 4 masks quadruples the batch dimension for the predictor output - out["predicted_patches"] = predictor.predict_targets(context_patches, masks_target) + predictor_pos_embed = apply_masks(predictor.predictor_pos_embed.repeat(context_patches.shape[0],1,1), mask_context) + out["predicted_patches"] = predictor.predict_targets( + predictor.project_context(context_patches) + predictor_pos_embed, + masks_target + ) out["loss"] = ijepa_loss(out["predicted_patches"], out["target_patches"]) # context embedding return out + module = ssl.Module( context_encoder=(ctx := IJEPA_Encoder(**encoder_kwargs)), target_encoder=TeacherStudentWrapper(ctx), predictor=IJEPA_Predictor(**predictor_kwargs), forward=forward, - ijepa_loss=partial(F.mse_loss, reduction='mean'), + ijepa_loss=F.smooth_l1_loss, optim={ "optimizer": { "type": "AdamW", @@ -386,12 +393,33 @@ def forward(self: ssl.Module, batch, stage): target_shape=(encoder_embed_dim * num_patches), ) +class PerModuleGradLogger(pl.Callback): + def __init__(self, modules=("predictor", "context_encoder", "target_encoder"), norm_type=2): + self.modules = modules + self.norm_type = norm_type + + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + device = pl_module.device + if trainer.global_step % 100 == 0: + for mod in self.modules: + group = [(n, p) for n, p in pl_module.named_parameters() if n.startswith(f"{mod}.")] + grads = [p.grad for _, p in group if p.grad is not None] + if not grads: + pl_module.log(f"grad/{mod}_norm", torch.tensor(0.0, device=device), on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", torch.tensor(0, device=device), on_step=True, logger=True, sync_dist=True) + continue + norms = torch.stack([g.detach().data.float().norm(self.norm_type).to(device) for g in grads]) + total_norm = torch.norm(norms, self.norm_type) + nonzero = (norms > 0).sum() + pl_module.log(f"grad/{mod}_norm", total_norm, on_step=True, logger=True, sync_dist=True) + pl_module.log(f"grad/{mod}_nz_params", nonzero, on_step=True, logger=True, sync_dist=True) + # Initialize W&B logger with explicit settings wandb_logger = WandbLogger( project="ijepa-cifar10", entity="samibg", # Your W&B entity - name="ijepa-inette-run", + name="ijepa-inet1k-run", log_model=False, # Set to True if you want to save model artifacts offline=False, # Ensure offline mode ) @@ -399,12 +427,14 @@ def forward(self: ssl.Module, batch, stage): trainer = pl.Trainer( max_epochs=num_epochs, num_sanity_val_steps=0, # Skip sanity check as queues need to be filled first - callbacks=[linear_probe, rankme], + callbacks=[linear_probe, rankme, PerModuleGradLogger(modules=("predictor","context_encoder","target_encoder")), + TeacherStudentCallback(update_frequency=1, update_after_backward=True), + ], precision="16-mixed", logger=wandb_logger, enable_checkpointing=False, accelerator="gpu", - devices=1, + devices=8, gradient_clip_val=max_grad_norm, strategy=DDPStrategy( find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module From 669b4dd8f28c60faf5586ccfddd28ee287d87ded Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 21 Aug 2025 23:55:46 +0000 Subject: [PATCH 27/28] vith ijepa test --- benchmarks/imagenet100/ijepa-vith.py | 73 +++++++++++++++-------- benchmarks/imagenet100/vicreg-resnet50.py | 6 +- stable_ssl/optim/lr_scheduler.py | 4 ++ 3 files changed, 55 insertions(+), 28 
deletions(-) diff --git a/benchmarks/imagenet100/ijepa-vith.py b/benchmarks/imagenet100/ijepa-vith.py index f676df52..e7127063 100644 --- a/benchmarks/imagenet100/ijepa-vith.py +++ b/benchmarks/imagenet100/ijepa-vith.py @@ -5,6 +5,9 @@ from torch.nn.init import trunc_normal_ from timm.models.vision_transformer import PatchEmbed, Block import lightning.pytorch.loggers as pl_loggers +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor + import torchmetrics from stable_ssl.backbone.utils import TeacherStudentWrapper from stable_ssl.callbacks.teacher_student import TeacherStudentCallback @@ -16,11 +19,18 @@ import lightning as pl from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed -NUM_DEVICES = 8 -BATCH_SIZE = 256 -VAL_BATCH_SIZE = (16384 // NUM_DEVICES) -# VAL_LR_SWEEPS = [0.01, 0.05, 0.001] -LR_MULTIPLIER = (BATCH_SIZE * NUM_DEVICES) / (128 * 16) +NUM_DEVICES = 1 + +# Match paper's setup exactly +EFFECTIVE_BATCH_SIZE = 2048 +BATCH_SIZE = EFFECTIVE_BATCH_SIZE // NUM_DEVICES + +EFFECTIVE_VAL_BATCH_SIZE = 16384 +VAL_BATCH_SIZE = EFFECTIVE_VAL_BATCH_SIZE // NUM_DEVICES + +# Use paper's learning rates directly +effective_lr = 0.001 +effective_probe_lr = 0.005 def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor: @@ -281,33 +291,44 @@ def standardize_masks(batch: list[dict]): ema_callback = TeacherStudentCallback(update_frequency=1, update_after_backward=True) rankme = ssl.callbacks.RankMe( name="rankme", target="flat", - queue_length=min(512, BATCH_SIZE), + queue_length=max(512, BATCH_SIZE), target_shape=(encoder.embed_dim * encoder.patch_embed.num_patches) ) -sweep_lr = 0.005 + linear_probe = ssl.callbacks.OnlineProbe( - name=f'linear_probe_lr{sweep_lr:.0e}', input='meanpool', target='label', + name=f'linear_probe', input='meanpool', target='label', probe=torch.nn.Sequential( torch.nn.BatchNorm1d(encoder.embed_dim, affine=False), torch.nn.Linear(encoder.embed_dim, 100) ), loss_fn=torch.nn.CrossEntropyLoss(), - optimizer={ - "type": "LARS", - "lr": sweep_lr, - "weight_decay": 1e-6, - }, - scheduler={ - "type": "StepLR", - "step_size": 15, - "gamma": 0.1, - }, + # optimizer={ + # "type": "LARS", + # "lr": effective_probe_lr, + # "weight_decay": 1e-6, + # }, + # scheduler={ + # "type": "StepLR", + # "step_size": 15 * len(train), + # "gamma": 1.0, + # }, metrics={ "top1": torchmetrics.classification.MulticlassAccuracy(100), "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5), } ) +knn_probe = ssl.callbacks.OnlineKNN( + name="knn_probe", + input="meanpool", + target="label", + queue_length=20000, + metrics={"accuracy": torchmetrics.classification.MulticlassAccuracy(100)}, + input_dim=768, + k=20, +) + + module = ssl.Module( context_encoder=encoder, target_encoder=target_encoder, @@ -317,15 +338,15 @@ def standardize_masks(batch: list[dict]): optim={ "optimizer": { "type": "AdamW", - "lr": 0.001 * LR_MULTIPLIER, + "lr": effective_lr, "weight_decay": 0.04, # TODO Scheduler to 0.4 }, "scheduler": { "type": "LinearWarmupCosineAnnealing", - "peak_step": 40 * len(train), + "peak_step": 40 * int(len(train) / NUM_DEVICES), "start_factor": 1 / 5, - "total_steps": 300 * len(train), - "end_lr": 1.0e-6 * LR_MULTIPLIER, + "total_steps": 300 * int(len(train) / NUM_DEVICES), + "end_lr": 1.0e-6, }, "interval": "step", } @@ -334,14 +355,16 @@ def standardize_masks(batch: list[dict]): trainer = pl.Trainer( max_epochs=300, num_sanity_val_steps=0, callbacks=[ - linear_probe, rankme, ema_callback + linear_probe, 
knn_probe, rankme, ema_callback, + ModelCheckpoint(monitor='train/loss', mode='min', save_top_k=1, save_last=True, every_n_epochs=10), + LearningRateMonitor(logging_interval='step') ], precision='16-mixed', logger=pl_loggers.WandbLogger( - project="ijepa-cifar10", entity="samibg", name="new-ijepa-inet100", + project="ijepa-cifar10", entity="samibg", name=f"new-ijepa-inet100-num-devices{NUM_DEVICES}", log_model=False, offline=False, ), - enable_checkpointing=False, + # enable_checkpointing=False, accelerator="gpu", devices=NUM_DEVICES, gradient_clip_val=None, strategy=DDPStrategy( find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module diff --git a/benchmarks/imagenet100/vicreg-resnet50.py b/benchmarks/imagenet100/vicreg-resnet50.py index ca233e75..bee8f628 100644 --- a/benchmarks/imagenet100/vicreg-resnet50.py +++ b/benchmarks/imagenet100/vicreg-resnet50.py @@ -155,9 +155,9 @@ def forward(self, batch, stage): ) wandb_logger = WandbLogger( - entity="stable-ssl", - project="imagenet100-vicreg", - name="vicreg-resnet18", + entity="samibg", + project="ijepa-cifar10", + name="vicreg-resnet50", log_model=False, ) diff --git a/stable_ssl/optim/lr_scheduler.py b/stable_ssl/optim/lr_scheduler.py index dabbcf50..13a1d7b5 100644 --- a/stable_ssl/optim/lr_scheduler.py +++ b/stable_ssl/optim/lr_scheduler.py @@ -395,3 +395,7 @@ def get_lr(self): / 2 for base_lr in self.base_lrs ] + + +if __name__ == "__main__": + _resolve_scheduler_callable("StepLR") \ No newline at end of file From 8d95cd943df59428c456f7d6814eea201579dd40 Mon Sep 17 00:00:00 2001 From: Sami Date: Fri, 22 Aug 2025 03:59:00 +0000 Subject: [PATCH 28/28] ijepa from scratch, still not converging --- benchmarks/imagenet100/ijepa_vith.py | 389 +++++++++++++++++++++++++++ 1 file changed, 389 insertions(+) create mode 100644 benchmarks/imagenet100/ijepa_vith.py diff --git a/benchmarks/imagenet100/ijepa_vith.py b/benchmarks/imagenet100/ijepa_vith.py new file mode 100644 index 00000000..63db525e --- /dev/null +++ b/benchmarks/imagenet100/ijepa_vith.py @@ -0,0 +1,389 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +from timm.models.vision_transformer import PatchEmbed, Block +import lightning.pytorch.loggers as pl_loggers +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor + +import torchmetrics +from stable_ssl.backbone.utils import TeacherStudentWrapper +from stable_ssl.callbacks.teacher_student import TeacherStudentCallback +from lightning.pytorch.strategies import DDPStrategy +from stable_ssl.data import transforms + +from stable_ssl.tests.scripts.masking import MaskCollator +import stable_ssl as ssl +import lightning as pl +from stable_ssl.utils.pos_embed import get_2d_sincos_pos_embed + +NUM_DEVICES = 4 +ZERO_LOSS = True + +# Match paper's setup exactly +EFFECTIVE_BATCH_SIZE = 2048 +BATCH_SIZE = EFFECTIVE_BATCH_SIZE // NUM_DEVICES + +EFFECTIVE_VAL_BATCH_SIZE = 16384 +VAL_BATCH_SIZE = EFFECTIVE_VAL_BATCH_SIZE // NUM_DEVICES + +# Use paper's learning rates directly +effective_lr = 0.001 +effective_probe_lr = 0.01 + + +def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor: + B, N, D = x.shape + M = len(masks) + idx = torch.stack( + [m.to(x.device, dtype=torch.long) for m in masks], dim=1 + ) # [B, M, K] + x_exp = x.unsqueeze(1).expand(-1, M, -1, -1) # [B, M, N, D] + out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D)) # [B, M, 
K, D] + return out.reshape(B * M, idx.size(-1), D) # [B*M, K, D] + +def repeat_interleave_batch(x: torch.Tensor, B, repeat): + N = x.shape[0] // B + x = x[:N*B] + return x.reshape(N, B, *x.shape[1:]) \ + .repeat_interleave(repeat, dim=1) \ + .reshape(N * B * repeat, *x.shape[1:]) + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +def fix_init_weight(blocks): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + +class IJEPA_ViT_Encoder(nn.Module): + def __init__( + self, + img_size=224, patch_size=14, embed_dim=768, + depth=12, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.norm_layer = nn.LayerNorm + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, + in_chans=3, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + # -- + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**.5), + cls_token=False) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + # -- + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, + mlp_ratio=4.0, qkv_bias=True, attn_drop=0.0, drop_path=0.0, + norm_layer=self.norm_layer + ) + for _ in range(depth) + ]) + self.norm = self.norm_layer(embed_dim) + self.apply(init_weights) + fix_init_weight(self.blocks) + + def forward(self, x, masks: None | list = None): + x = self.patch_embed(x) + x = x + self.pos_embed + + if masks is not None: + x = apply_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + return self.norm(x) + +class IJEPA_ViT_Predictor(nn.Module): + def __init__( + self, + num_patches, embed_dim=768, predictor_embed_dim=384, + depth=6, num_heads=12, + ): + super().__init__() + self.embed_dim = embed_dim + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + # -- + self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + int(num_patches**.5), + cls_token=False) + self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + # -- + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, num_heads=num_heads, + mlp_ratio=4, qkv_bias=True, + attn_drop=0.0, drop_path=0.0, + norm_layer=nn.LayerNorm + ) + for _ in range(depth) + ]) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + # ------ + trunc_normal_(self.mask_token, std=0.02) + self.apply(init_weights) + fix_init_weight(self.predictor_blocks) + + def forward(self, context_patches, masks_context: list, masks_target: list): + assert len(masks_context) == 1 + B = len(context_patches) // len(masks_context) + x = 
self.predictor_embed(context_patches) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(x_pos_embed, masks_context) + + _, N_ctxt, D = x.shape + + # -- concat mask tokens to x + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_target) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_context)) + # -- + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + # -- + pred_tokens += pos_embs + x = x.repeat(len(masks_target), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # -- fwd prop + for blk in self.predictor_blocks: + x = blk(x) + x = self.predictor_norm(x) + + # -- return preds for mask tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + +def forward_ijepa(self: ssl.Module, batch: dict, stage: str) -> dict: + out = {} + + with torch.no_grad(): + # -- keys for rankme and linear probe + out['context_embeddings'] = self.context_encoder(batch['image'], masks = None) + out['meanpool'] = out['context_embeddings'].mean(dim=1) + out['flat'] = out['context_embeddings'].reshape(out['context_embeddings'].shape[0], -1) + + if not self.training: + return out + + with torch.no_grad(): + target_patches = self.target_encoder(batch['image']) + target_patches = F.layer_norm(target_patches, (target_patches.size(-1),)) + target_patches = apply_masks(target_patches, batch['masks_target']) + out['target_embeddings'] = target_patches + + out['context_embeddings'] = self.context_encoder(batch['image'], [batch['mask_context']] if not isinstance(batch['mask_context'], list) else batch['mask_context']) + out['predictions'] = self.predictor( + out['context_embeddings'], + [batch['mask_context']] if not isinstance(batch['mask_context'], list) else batch['mask_context'], + batch['masks_target'] + ) + out['loss'] = self.ijepa_loss(out['predictions'], out['target_embeddings']) * int(not self.zero_loss) + return out + + +inet1k_train = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="train", + transform=transforms.Compose( + transforms.RGB(), + transforms.RandomResizedCrop((224, 224), scale=(0.3, 1.0)), + transforms.ContextTargetsMultiBlockMask( + patch_size=14, + context_scale=(0.85, 1.0), + context_aspect_ratio=(1.0, 1.0), + target_scales=((0.15, 0.2),) * 4, + target_aspect_ratios=((0.75, 1.5),) * 4, + min_keep=10, + ), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +inet1k_val = ssl.data.HFDataset( + path="clane9/imagenet-100", + split="validation", + transform=transforms.Compose( + transforms.RGB(), + transforms.Resize((256, 256)), + transforms.CenterCrop((224, 224)), + transforms.ToImage( + mean=(0.485, 0.456, 0.406), + std= (0.229, 0.224, 0.225), + ), + ), +) + +# collate function that makes each mask have the same number of indices, so they can be batched +def standardize_masks(batch: list[dict]): + context_indices = [item.pop("mask_context") for item in batch] + target_indices = [item.pop("masks_target") for item in batch] + batch = torch.utils.data.default_collate(batch) + + min_keep_enc = min(len(ctx) for ctx in context_indices) + min_keep_pred = min( + len(block) for multiblock in target_indices for block in multiblock + ) + + context_batch = [ctx[:min_keep_enc] for ctx in context_indices] + target_batch = [ + [tgt[:min_keep_pred] for tgt in multiblock] for multiblock in target_indices + ] + + collated_masks_context = torch.utils.data.default_collate(context_batch) + collated_masks_target = 
torch.utils.data.default_collate(target_batch) + + batch["mask_context"] = collated_masks_context + batch["masks_target"] = collated_masks_target + return batch + +mask_transform_kwargs = dict( + input_size=(224, 224), + patch_size=14, + enc_mask_scale=(0.2, 0.8), + pred_mask_scale=(0.15, 0.2), + aspect_ratio=(0.75, 1.5), + nenc=1, + npred=4, + min_keep=10, +) +train = torch.utils.data.DataLoader( + dataset=inet1k_train, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=64, + drop_last=True, + collate_fn=standardize_masks, + pin_memory=True, + persistent_workers=True, +) + +val = torch.utils.data.DataLoader( + dataset=inet1k_val, + batch_size=VAL_BATCH_SIZE, + num_workers=64, + shuffle=False, +) + +encoder = IJEPA_ViT_Encoder() +predictor = IJEPA_ViT_Predictor(num_patches=encoder.patch_embed.num_patches) +target_encoder = TeacherStudentWrapper(encoder, base_ema_coefficient=0.996, final_ema_coefficient=1.0) + +ema_callback = TeacherStudentCallback(update_frequency=1, update_after_backward=True) +rankme = ssl.callbacks.RankMe( + name="rankme", target="flat", + queue_length=max(512, BATCH_SIZE), + target_shape=(encoder.embed_dim * encoder.patch_embed.num_patches) +) + +linear_probe = ssl.callbacks.OnlineProbe( + name=f'linear_probe', input='meanpool', target='label', + probe=torch.nn.Sequential( + torch.nn.BatchNorm1d(encoder.embed_dim, affine=False), + torch.nn.Linear(encoder.embed_dim, 100) + ), + loss_fn=torch.nn.CrossEntropyLoss(), + optimizer={ + "type": "LARS", + "lr": effective_probe_lr, + "weight_decay": 1e-6, + }, + scheduler={ + "type": "StepLR", + "step_size": 15 * len(train), + "gamma": 1.0, + }, + metrics={ + "top1": torchmetrics.classification.MulticlassAccuracy(100), + "top5": torchmetrics.classification.MulticlassAccuracy(100, top_k=5), + } + ) + +module = ssl.Module( + context_encoder=encoder, + target_encoder=target_encoder, + predictor=predictor, + zero_loss=ZERO_LOSS, + forward=forward_ijepa, + ijepa_loss=F.smooth_l1_loss, + optim={ + "optimizer": { + "type": "AdamW", + "lr": effective_lr, + "weight_decay": 0.04, # TODO Scheduler to 0.4 + }, + "scheduler": { + "type": "LinearWarmupCosineAnnealing", + "peak_step": 40 * int(len(train) / NUM_DEVICES), + "start_factor": 1 / 5, + "total_steps": 300 * int(len(train) / NUM_DEVICES), + "end_lr": 1.0e-6, + }, + "interval": "step", + } +) + +trainer = pl.Trainer( + max_epochs=300, num_sanity_val_steps=0, + limit_train_batches = None if not ZERO_LOSS else 1, + callbacks=[ + linear_probe, + rankme, ema_callback, + ModelCheckpoint(monitor='train/loss', mode='min', save_top_k=1, save_last=True, every_n_epochs=10), + LearningRateMonitor(logging_interval='step') + ], + precision='16-mixed', + logger=pl_loggers.WandbLogger( + project="ijepa-cifar10", entity="samibg", name=f"new-ijepa-inet100-num-devices{NUM_DEVICES}-zero-loss{ZERO_LOSS}", + log_model=False, offline=False, + ), + # enable_checkpointing=False, + accelerator="gpu", devices=NUM_DEVICES, gradient_clip_val=None, + strategy=DDPStrategy( + find_unused_parameters=True, # this is because only teacher's params are used in the teacher-student module + static_graph=True, + gradient_as_bucket_view=True, + ) +) + +data = ssl.data.DataModule(train=train, val=val) + +manager = ssl.Manager(trainer=trainer, + module=module, + data=data, + ckpt_path='/home/sky/stable-SSL/ijepa-cifar10/yxvs6tll/checkpoints/epoch=119-step=7320.ckpt', +) + +if __name__ == "__main__": + manager() \ No newline at end of file
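A note on the masking utility shared by the two benchmark scripts above: apply_masks gathers, for each of the M index masks, the K selected patch embeddings of every image and flattens the result into the batch dimension, turning a [B, N, D] tensor into [B*M, K, D]. Below is a minimal, self-contained sanity check of that shape contract; it re-declares the function locally rather than importing it from the benchmark file, and the tensor sizes are illustrative only.

import torch

def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
    # x: [B, N, D] patch embeddings; each mask: [B, K] indices of the patches to keep.
    B, N, D = x.shape
    M = len(masks)
    idx = torch.stack([m.to(x.device, dtype=torch.long) for m in masks], dim=1)  # [B, M, K]
    x_exp = x.unsqueeze(1).expand(-1, M, -1, -1)                                  # [B, M, N, D]
    out = x_exp.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, D))                # [B, M, K, D]
    # Rows come out image-major: (img 0, mask 0), (img 0, mask 1), ..., (img 1, mask 0), ...
    return out.reshape(B * M, idx.size(-1), D)                                    # [B*M, K, D]

if __name__ == "__main__":
    # Illustrative sizes: 2 images, 256 patches, 768-dim embeddings, 4 target masks of 40 indices each.
    x = torch.randn(2, 256, 768)
    masks = [torch.randint(0, 256, (2, 40)) for _ in range(4)]
    print(apply_masks(x, masks).shape)  # torch.Size([8, 40, 768])

Given the "still not converging" note on the last patch, one detail worth double-checking is this image-major row ordering: the reference I-JEPA implementation's apply_masks concatenates one mask at a time along the batch dimension (mask-major), and code ported from it, such as x.repeat(len(masks_target), 1, 1) in the predictor, assumes that layout when pairing repeated context rows with the stacked mask tokens.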
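The standardize_masks collate function relies on default_collate only being able to stack index tensors of equal length, so every context and target index list is truncated to the shortest one in the batch before collation. A small sketch of that truncation step, with made-up mask lengths:

import torch
from torch.utils.data import default_collate

# Two samples whose context masks ended up with different lengths (values are illustrative only).
context_indices = [torch.arange(150), torch.arange(163)]

# Truncate every mask to the shortest length in the batch so default_collate can stack them.
min_keep_enc = min(len(ctx) for ctx in context_indices)
mask_context = default_collate([ctx[:min_keep_enc] for ctx in context_indices])
print(mask_context.shape)  # torch.Size([2, 150])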