IJEPA #244 (Open)

Commits (30)
6dc75fe  IJEPA Multiblock masking with visualizations  (sami-bg, Aug 9, 2025)
3291547  Multiblock masking complete  (sami-bg, Aug 9, 2025)
990ac13  docstrings  (sami-bg, Aug 9, 2025)
517889c  transform  (sami-bg, Aug 9, 2025)
1cd7c32  masking fixed, collate fixed, ijepa nearly done  (sami-bg, Aug 10, 2025)
db9860a  small bug  (sami-bg, Aug 10, 2025)
e1fab88  import  (sami-bg, Aug 10, 2025)
6367714  IJEPA - need to do sinusoidal posemb still  (sami-bg, Aug 11, 2025)
5853147  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Aug 11, 2025)
c4052ab  ijepa example, masked args context/target  (sami-bg, Aug 17, 2025)
43329be  merge  (sami-bg, Aug 17, 2025)
ad64c0f  precommits  (sami-bg, Aug 17, 2025)
979fc91  precommits  (sami-bg, Aug 17, 2025)
5b1329b  inet test script  (sami-bg, Aug 17, 2025)
6254db3  IJEPA INET1k HF dataset  (sami-bg, Aug 17, 2025)
129997e  mae, random masking, inet  (sami-bg, Aug 17, 2025)
509afdd  WIP MAE simplifying arch  (sami-bg, Aug 17, 2025)
de7f9fb  more testing archs  (sami-bg, Aug 17, 2025)
f7e7678  more testing archs  (sami-bg, Aug 17, 2025)
603977a  mae cifar10  (sami-bg, Aug 18, 2025)
a1f84df  mae cifar10  (sami-bg, Aug 18, 2025)
c5c97a7  inet1k mae  (sami-bg, Aug 18, 2025)
9d3344d  inetk mae  (sami-bg, Aug 18, 2025)
251daaa  todo gradnorm  (sami-bg, Aug 19, 2025)
a316f99  merge  (sami-bg, Aug 19, 2025)
d5cf9d7  fixes and sweeps  (sami-bg, Aug 19, 2025)
5ec568e  fixed multiblock masking flatten bug, sweeping hparams  (sami-bg, Aug 20, 2025)
07241a9  a lot of debugging, new script  (sami-bg, Aug 21, 2025)
669b4dd  vith ijepa test  (sami-bg, Aug 21, 2025)
8d95cd9  ijepa from scratch, still not converging  (sami-bg, Aug 22, 2025)
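
The history above centers on I-JEPA's multiblock masking (added in 6dc75fe, completed in 3291547, and debugged through the flatten fix in 5ec568e). As a reference for what that sampling typically looks like, here is a minimal sketch of multiblock mask generation on a ViT patch grid. It follows the I-JEPA paper's defaults (four target blocks at 15-20% scale, one large context block with target patches removed); all function names, scale ranges, and grid sizes here are assumptions, not the code added in this PR.

# Illustrative sketch of I-JEPA-style multiblock masking, not the PR's code.
import math
import random

def sample_block(grid_h: int, grid_w: int, scale: tuple, aspect: tuple) -> set:
    """Sample one rectangular block of patch indices on a (grid_h, grid_w) grid."""
    num_patches = grid_h * grid_w
    area = num_patches * random.uniform(*scale)
    ratio = random.uniform(*aspect)  # height / width
    h = min(grid_h, max(1, round(math.sqrt(area * ratio))))
    w = min(grid_w, max(1, round(math.sqrt(area / ratio))))
    top = random.randint(0, grid_h - h)
    left = random.randint(0, grid_w - w)
    return {(top + i) * grid_w + (left + j) for i in range(h) for j in range(w)}

def multiblock_masks(grid_h: int = 14, grid_w: int = 14, num_targets: int = 4):
    """Return (context_indices, list_of_target_index_lists), sorted."""
    targets = [sample_block(grid_h, grid_w, scale=(0.15, 0.2), aspect=(0.75, 1.5))
               for _ in range(num_targets)]
    context = sample_block(grid_h, grid_w, scale=(0.85, 1.0), aspect=(1.0, 1.0))
    # The context must not overlap any target: the predictor has to infer
    # the unseen target blocks from the visible context patches alone.
    context -= set().union(*targets)
    return sorted(context), [sorted(t) for t in targets]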
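
Commits 129997e through 9d3344d add an MAE baseline, which masks a random subset of patches rather than contiguous blocks. For contrast with the block sampling above, a minimal sketch of MAE-style random masking, again following the published recipe (75% mask ratio, per-sample random permutation) rather than this PR's implementation:

# Illustrative sketch of MAE-style random masking, not the PR's code.
import torch

def random_masking(x: torch.Tensor, mask_ratio: float = 0.75):
    """x: (batch, num_patches, dim). Keep a random (1 - mask_ratio) subset."""
    b, n, d = x.shape
    num_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=x.device)  # one random score per patch
    shuffle = torch.argsort(noise, dim=1)      # per-sample random permutation
    keep = shuffle[:, :num_keep]               # indices of visible patches
    x_visible = torch.gather(x, 1, keep.unsqueeze(-1).expand(-1, -1, d))
    mask = torch.ones(b, n, device=x.device)   # 1 = masked, 0 = visible
    mask.scatter_(1, keep, 0.0)
    return x_visible, mask, shuffle            # shuffle allows later unshuffling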
assets/benchmarks/cifar10/byol.yaml (2 additions, 2 deletions)

@@ -13,14 +13,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: True
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 128]
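
The same rename, stable_ssl.modules.TeacherStudentModule to stable_ssl.backbone.utils.TeacherStudentWrapper, repeats across the five config diffs below. For readers unfamiliar with the pattern these configs wire up, here is a generic sketch of an EMA teacher-student wrapper: the student trains by gradient descent while a frozen teacher copy tracks it via exponential moving average. The real stable_ssl class may differ; the class name, constructor arguments, and method names here are assumptions.

# Generic sketch of the teacher-student (EMA) pattern, not stable_ssl's class.
import copy
import torch

class EmaTeacherStudent(torch.nn.Module):
    def __init__(self, student: torch.nn.Module, momentum: float = 0.996):
        super().__init__()
        self.student = student
        self.teacher = copy.deepcopy(student)  # frozen EMA copy of the student
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.momentum = momentum

    @torch.no_grad()
    def update_teacher(self):
        # teacher <- m * teacher + (1 - m) * student, called once per step
        for pt, ps in zip(self.teacher.parameters(), self.student.parameters()):
            pt.mul_(self.momentum).add_(ps, alpha=1.0 - self.momentum)

    def forward(self, x):
        # Gradients flow through the student branch only.
        return self.student(x)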
assets/benchmarks/cifar10/dino.yaml (2 additions, 2 deletions)

@@ -9,14 +9,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: True
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 2048]
assets/benchmarks/cifar100/byol.yaml (2 additions, 2 deletions)

@@ -13,14 +13,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: True
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 128]
assets/benchmarks/cifar100/dino.yaml (2 additions, 2 deletions)

@@ -9,14 +9,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: True
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 2048]
assets/benchmarks/imagenette/byol.yaml (2 additions, 2 deletions)

@@ -13,14 +13,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: False
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 128]
assets/benchmarks/imagenette/dino.yaml (2 additions, 2 deletions)

@@ -9,14 +9,14 @@ trainer:
   # ===== Module Parameters =====
   module:
     backbone:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.load_backbone
         name: resnet50
         low_resolution: False
         num_classes: null
     projector:
-      _target_: stable_ssl.modules.TeacherStudentModule
+      _target_: stable_ssl.backbone.utils.TeacherStudentWrapper
       student:
         _target_: stable_ssl.modules.MLP
         sizes: [2048, 2048, 2048]
benchmarks/cifar10/vicreg-resnet18.py (6 additions, 4 deletions)

@@ -153,10 +153,11 @@ def forward(self, batch, stage):
 )

 wandb_logger = WandbLogger(
-    entity="stable-ssl",
-    project="cifar10-vicreg",
-    name="vicreg-resnet18",
-    log_model=False,
+    project="ijepa-cifar10",
+    entity="samibg",  # Your W&B entity
+    name="vicreg-cifar10-run",
+    log_model=False,  # Set to True if you want to save model artifacts
+    offline=False,  # Set to True to run W&B in offline mode
 )

 trainer = pl.Trainer(
@@ -165,6 +166,7 @@ def forward(self, batch, stage):
     callbacks=[knn_probe, linear_probe],
     precision="16-mixed",
     logger=wandb_logger,
+    devices=1,
     enable_checkpointing=False,
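
Taken together, the updated logger and trainer wiring amounts to the standalone sketch below. It assumes pytorch_lightning is installed with the wandb extra, and it omits the LightningModule, data loaders, and the knn_probe/linear_probe callbacks defined elsewhere in the script.

# Standalone sketch of the post-change wiring; surrounding setup omitted.
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    project="ijepa-cifar10",
    entity="samibg",    # W&B team or username
    name="vicreg-cifar10-run",
    log_model=False,    # set True to upload checkpoints as W&B artifacts
    offline=False,      # set True to log locally and sync later
)

trainer = pl.Trainer(
    devices=1,          # pin to a single device for reproducible runs
    precision="16-mixed",
    logger=wandb_logger,
    enable_checkpointing=False,
)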