A small, CPU-friendly “mini research playground” for JEPA-style representation learning: predict target embeddings (from an EMA target encoder) rather than reconstructing pixels.
```bash
python -m venv .venv
source .venv/bin/activate
pip install -U pip
pip install -e ".[dev]"
```

Python 3.10+ works; Python 3.11+ is recommended.
Train:

```bash
jepa-toy train --task sine --config configs/sine/base.yaml
jepa-toy train --task tokens --config configs/tokens/base.yaml
jepa-toy train --task gridworld --config configs/gridworld/base.yaml
```

Evaluate:
```bash
jepa-toy eval --task sine --run_dir runs/sine/<timestamp>_<tag>
```

TensorBoard:
```bash
tensorboard --logdir runs
```

Tiny sweep (writes runs/summary.csv):
```bash
python scripts/run_sweeps.py --tiny
```

Components:

- Online encoder `f(x) -> z`
- Predictor `g(z_context, condition) -> z_pred`
- Target encoder `f_t(x) -> z_t` is an EMA copy of the online encoder (stop-grad)
- Loss matches `z_pred` to `stopgrad(z_t)` (MSE or cosine), optionally with variance regularization; see the sketch after this list
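For orientation, here is a minimal sketch of one training step under these definitions. The module and variable names are illustrative assumptions, not the repo's actual classes, and the predictor's conditioning input is omitted for brevity:

```python
# Hypothetical JEPA-style training step (assumed names, not jepa_toy's code).
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 32
encoder = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, dim))      # f: online encoder
predictor = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, dim))  # g (condition omitted)
target_encoder = copy.deepcopy(encoder)                                       # f_t: EMA copy of f
for p in target_encoder.parameters():
    p.requires_grad_(False)  # stop-grad: targets never receive gradients

opt = torch.optim.Adam([*encoder.parameters(), *predictor.parameters()], lr=1e-3)

def training_step(x_context, x_target, ema_decay=0.99):
    z_context = encoder(x_context)
    z_pred = predictor(z_context)          # predicted target embedding
    with torch.no_grad():
        z_t = target_encoder(x_target)     # stopgrad(z_t)
    loss = F.mse_loss(z_pred, z_t)         # cosine variant: (1 - F.cosine_similarity(z_pred, z_t, dim=1)).mean()
    # Optional variance regularizer: penalize per-dimension std below 1
    loss = loss + F.relu(1.0 - z_pred.std(dim=0)).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # EMA update: f_t <- decay * f_t + (1 - decay) * f
    with torch.no_grad():
        for p_t, p in zip(target_encoder.parameters(), encoder.parameters()):
            p_t.mul_(ema_decay).add_(p, alpha=1.0 - ema_decay)
    return loss.item()

print(training_step(torch.randn(16, 8), torch.randn(16, 8)))
```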
We log collapse diagnostics: per-dimension variance/std, covariance spectrum (top singular values), and batch cosine similarity.
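As a rough illustration of what those diagnostics measure (the function name and exact statistics are assumptions, not the repo's logging code):

```python
# Sketch of collapse diagnostics over a batch of embeddings z: (N, D).
import torch
import torch.nn.functional as F

def collapse_diagnostics(z: torch.Tensor, top_k: int = 8) -> dict:
    z_norm = F.normalize(z, dim=1)
    cos = z_norm @ z_norm.T                         # pairwise batch cosine similarity
    off_diag = cos[~torch.eye(len(z), dtype=torch.bool)]
    zc = z - z.mean(dim=0)                          # center per dimension
    per_dim_std = zc.std(dim=0)                     # near-zero everywhere => collapse
    cov = (zc.T @ zc) / (len(z) - 1)                # (D, D) covariance
    svals = torch.linalg.svdvals(cov)               # a few dominant values => low-rank embeddings
    return {
        "std_mean": per_dim_std.mean().item(),
        "std_min": per_dim_std.min().item(),
        "top_singular_values": svals[:top_k].tolist(),
        "mean_pairwise_cosine": off_diag.mean().item(),  # -> 1.0 when all embeddings align
    }

print(collapse_diagnostics(torch.randn(64, 32)))
```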
To add a new task:

- Implement `src/jepa_toy/tasks/<name>/task.py` with `get_task_spec()` (a hypothetical skeleton follows below)
- Add `configs/<name>/base.yaml`
- Add `docs/tasks/<name>.md`
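The shape of `get_task_spec()` below is a guess for illustration; check an existing task under `src/jepa_toy/tasks/` for the real interface and fields:

```python
# src/jepa_toy/tasks/<name>/task.py -- hypothetical skeleton. The TaskSpec
# shape, field names, and sampler signature are all assumptions.
from dataclasses import dataclass
from typing import Callable
import torch

@dataclass
class TaskSpec:
    name: str
    input_dim: int
    embed_dim: int
    sample_batch: Callable[[int], tuple[torch.Tensor, torch.Tensor]]

def _sample_batch(batch_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Return (context, target) views; here, a signal and a shifted copy of it.
    t = torch.rand(batch_size, 1) * 2 * torch.pi
    context = torch.sin(t + torch.linspace(0, 1, 8))
    target = torch.sin(t + 0.5 + torch.linspace(0, 1, 8))
    return context, target

def get_task_spec() -> TaskSpec:
    return TaskSpec(name="my_task", input_dim=8, embed_dim=32,
                    sample_batch=_sample_batch)
```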