From a2d66e17fd527f90cfed59d3871b58436d6f82bd Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 15:43:02 +0530 Subject: [PATCH 01/29] add v0 --- examples/leWorldModel/README.md | 248 ++++++++ examples/leWorldModel/bench.py | 353 ++++++++++++ examples/leWorldModel/config/lewm_pusht.yaml | 42 ++ examples/leWorldModel/create_data.py | 537 ++++++++++++++++++ examples/leWorldModel/dataset.md | 189 ++++++ examples/leWorldModel/eda_analysis.py | 490 ++++++++++++++++ examples/leWorldModel/lewm_lance/__init__.py | 9 + .../leWorldModel/lewm_lance/dataloaders.py | 147 +++++ examples/leWorldModel/lewm_lance/dataset.py | 231 ++++++++ examples/leWorldModel/requirements.txt | 28 + examples/leWorldModel/train.py | 365 ++++++++++++ 11 files changed, 2639 insertions(+) create mode 100644 examples/leWorldModel/README.md create mode 100644 examples/leWorldModel/bench.py create mode 100644 examples/leWorldModel/config/lewm_pusht.yaml create mode 100644 examples/leWorldModel/create_data.py create mode 100644 examples/leWorldModel/dataset.md create mode 100644 examples/leWorldModel/eda_analysis.py create mode 100644 examples/leWorldModel/lewm_lance/__init__.py create mode 100644 examples/leWorldModel/lewm_lance/dataloaders.py create mode 100644 examples/leWorldModel/lewm_lance/dataset.py create mode 100644 examples/leWorldModel/requirements.txt create mode 100644 examples/leWorldModel/train.py diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md new file mode 100644 index 0000000..855f04c --- /dev/null +++ b/examples/leWorldModel/README.md @@ -0,0 +1,248 @@ +# leWorldModel × LanceDB + +End-to-end training of [leWorldModel](https://github.com/lucas-maes/le-wm) (a JEPA-based world model) backed by LanceDB instead of HDF5. + +``` +examples/leWorldModel/ +├── create_data.py # HDF5 → LanceDB conversion + Geneva embedding backfill +├── train.py # Trainer: LanceDB loaders, identical LeWM model/loss +├── eda_analysis.py # EDA, quality scan, splits, vector search +├── config/ +│ └── lewm_pusht.yaml # Example config (copy and edit per dataset) +└── lewm_lance/ + ├── dataset.py # LeWMLanceDataset — temporal window sampler + └── dataloaders.py # make_train_val_loaders() factory +``` + +--- + +## What is leWorldModel? + +LeWM is a Joint Embedding Predictive Architecture (JEPA) that learns a world model from raw pixels with two losses: + +- **Next-embedding prediction** — MSE between predicted and actual next latent state +- **SIGReg** (Sketch Isotropic Gaussian Regularizer) — keeps the latent space well-shaped + +The model is ~15M parameters: a ViT-tiny encoder, an autoregressive predictor, and an action embedder. It trains stably on a single GPU. + +The paper evaluates on four datasets independently — they are not mixed during training: + +| Dataset | Env | Modalities | Config | +|---------|-----|-----------|--------| +| DMControl Reacher | `reacher` | pixels, action, observation | `lewm_reacher.yaml` | +| OGBench Cube | `cube_single_expert` | pixels, action, observation | `lewm_cube.yaml` | +| PushT | `pusht_expert_train` | pixels, action, proprio, state | `lewm_pusht.yaml` | +| TwoRoom | `tworoom` | pixels, action, proprio | `lewm_tworoom.yaml` | + +--- + +## Hardware + +LeWM is intentionally small (~15M params) and trains on a single GPU. These are practical recommendations: + +| GPU | VRAM | batch_size | Notes | +|-----|------|-----------|-------| +| RTX 3090 / 4090 | 24 GB | 128 | Matches paper. ~4–6 hrs per dataset at 100 epochs. | +| A100 40 GB | 40 GB | 256 | 2× faster than 3090. Use if available. | +| A100 80 GB / H100 | 80 GB | 512 | Overkill for LeWM alone; useful if running multiple seeds in parallel. | +| RTX 3080 / 4070 | 10–12 GB | 64 | Reduce `batch_size` and `num_workers` to fit. Scale `lr` linearly. | + +Training uses `bf16-mixed` precision throughout. If your GPU does not support bf16 (pre-Ampere), change `precision: "16-mixed"` in the config. + +For the DataLoader, `num_workers=6` works well with a local LanceDB store. With S3-backed storage, increase to `num_workers=8–12` to overlap network I/O with GPU compute. + +--- + +## Reproducing the paper + +### Step 1 — Convert datasets + +Download the datasets via `stable-worldmodel` and convert each to a LanceDB table. +All four datasets can share a single LanceDB store (each gets its own table). + +```bash +# All four datasets into one local store +python create_data.py --dataset all --lance-uri ./lewm_lance + +# Or one at a time +python create_data.py --dataset reacher --lance-uri ./lewm_lance +python create_data.py --dataset cube --lance-uri ./lewm_lance +python create_data.py --dataset pusht --lance-uri ./lewm_lance +python create_data.py --dataset tworoom --lance-uri ./lewm_lance + +# S3-backed store (credentials via env or --aws-* flags) +python create_data.py --dataset all --lance-uri s3://my-bucket/lewm +``` + +This creates four tables: `lewm_reacher`, `lewm_cube`, `lewm_pusht`, `lewm_tworoom`. + +### Step 2 — EDA and data quality check + +Run this **before** training to catch any data issues and understand each dataset. +Uses DINOv2 embeddings (no training needed — frozen foundation model). + +```bash +# Quality scan + statistics on each dataset +python eda_analysis.py --table lewm_pusht --section quality +python eda_analysis.py --table lewm_reacher --section quality +python eda_analysis.py --table lewm_cube --section quality +python eda_analysis.py --table lewm_tworoom --section quality + +# Add DINOv2 embeddings for pre-training EDA (semantic search, clustering, dedup) +# Requires LanceDB Enterprise (Geneva) — skip if unavailable +python create_data.py --dataset all --embed --embedding-model dinov2 + +# Explore with embeddings +python eda_analysis.py --table lewm_pusht --section vector_search --emb-col emb_dinov2 +python eda_analysis.py --table lewm_pusht --section entropy # find diverse episodes +python eda_analysis.py --table lewm_pusht --section stats +``` + +### Step 3 — Train on each dataset + +Each dataset is trained independently. Create a config per dataset by copying +`config/lewm_pusht.yaml` and updating `data.table_name` and `data.columns`. + +```bash +# PushT +python train.py --config config/lewm_pusht.yaml + +# Reacher — override table and columns directly without a separate config +python train.py --config config/lewm_pusht.yaml \ + --table-name lewm_reacher \ + --columns pixels action observation + +# Cube +python train.py --config config/lewm_pusht.yaml \ + --table-name lewm_cube \ + --columns pixels action observation + +# TwoRoom +python train.py --config config/lewm_pusht.yaml \ + --table-name lewm_tworoom \ + --columns pixels action proprio + +# S3-backed store with explicit credentials +python train.py --config config/lewm_pusht.yaml \ + --lance-uri s3://my-bucket/lewm \ + --aws-region us-east-1 \ + --aws-access-key-id AKIA... \ + --aws-secret-access-key ... + +# With credentials in environment (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_DEFAULT_REGION) +python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm +``` + +### Step 4 — Post-training analysis + +After training, add LeWM encoder embeddings to the table for world-model-specific analysis. +These reflect what the trained model considers semantically similar — different from DINOv2. + +```bash +python create_data.py \ + --dataset pusht \ + --embed \ + --embedding-model lewm \ + --checkpoint ./checkpoints/lewm_pusht_lewm_epoch_99_object.ckpt + +# Compare DINOv2 vs LeWM similarity structure +python eda_analysis.py --table lewm_pusht --section vector_search --emb-col emb_lewm +python eda_analysis.py --table lewm_pusht --section retrieval --emb-col emb_lewm +``` + +--- + +## Benchmarking dataloader throughput + +`bench.py` measures raw dataloader throughput (samples/sec) for three backends +independently of GPU compute, so differences are purely from data loading. + +```bash +# LanceDB S3 vs HDF5 local (put the HDF5 file in /dev/shm for best-case comparison) +python bench.py \ + --lance-uri s3://my-bucket/lewm \ + --table-name lewm_pusht \ + --hdf5-local /dev/shm/pusht.hdf5 + +# Add HDF5-from-S3 via s3fs (reads HDF5 directly from S3, no download) +python bench.py \ + --lance-uri s3://my-bucket/lewm \ + --table-name lewm_pusht \ + --hdf5-local /dev/shm/pusht.hdf5 \ + --hdf5-s3-key hdf5/pusht.hdf5 \ + --s3-bucket my-bucket + +# Credentials via env vars (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_DEFAULT_REGION) +python bench.py --lance-uri s3://my-bucket/lewm --table-name lewm_pusht +``` + +Why the gap: + +- **LanceDB S3 vs HDF5 local**: HDF5 serializes workers through a POSIX file lock — effective parallelism is ~1 worker regardless of `num_workers`. LanceDB workers hold independent connections with no locking. +- **HDF5 s3fs**: HDF5 makes many small random seeks per batch. Each seek over S3 becomes a separate HTTP range request. For temporal window reads (T=4 rows × multiple columns), this is dozens of round-trips per batch. + +--- + +## How the temporal window sampler works + +LeWM needs **contiguous T=4 frame windows** from the same episode per training sample. + +`LeWMLanceDataset` precomputes valid window positions at init time: + +1. Loads only `(episode_idx, step_idx)` into memory (~16 bytes/row — negligible even at millions of steps) +2. Checks all consecutive row pairs for same-episode + sequential step constraints +3. Stores the resulting `_window_starts` array (int64 numpy) + +At training time, `__getitems__(window_indices)` fetches all **B×T rows in one `Permutation.__getitems__`** call and splits into per-sample dicts. No N×B individual lookups. + +--- + +## Multi-worker safety + +The LanceDB `Permutation` holds Rust async state that cannot be pickled. Each DataLoader worker gets a zeroed-out copy and lazily rebuilds its own connection: + +```python +def __getstate__(self): + state = self.__dict__.copy() + state["_perm"] = None + return state + +def _ensure_open(self): + if self._perm is None: + db = lancedb.connect(self.uri, **self.connect_kwargs) + self._perm = Permutation.identity(db.open_table(...))... +``` + +Combined with `multiprocessing_context="spawn"` and `persistent_workers=True`. + +--- + +## LanceDB vs HDF5 + +| Feature | LanceDB | HDF5 | +|---------|---------|------| +| Multi-process reads | Yes (per-worker connection) | No (POSIX file lock) | +| Columnar partial reads | Native Arrow | Compound datasets only | +| Vector / ANN search | Built-in IVF-PQ | Not supported | +| SQL-like episode filters | `episode_idx = 42` | Loop + mask in Python | +| Cloud-native (S3/GCS) | Native, parallel | Download first | +| Schema evolution | Add columns in-place | Limited | +| Versioning / time-travel | Yes | No | +| Embedding storage | Native vector column | Separate dataset | +| Train/val split | Filter query, zero copy | Copy or index arrays | +| JPEG pixel compression | ~13× smaller than raw uint8 | Raw arrays only | +| Concurrent writers | Yes | No | + +The key practical difference for LeWM training: HDF5 serializes all DataLoader workers through a single file lock, limiting effective parallelism to ~1–2 workers regardless of how many you spawn. LanceDB workers each hold their own connection with no contention. + +--- + +## What else you can do with multimodal robotics data in LanceDB + +1. **Pre-training data curation**: use DINOv2 ANN search to deduplicate near-identical episodes before spending GPU hours on them. +2. **Curriculum learning**: rank episodes by action entropy (`eda_analysis.py --section entropy`) and present easy→hard schedules during training. +3. **Goal-conditioned retrieval**: encode a goal frame, search `emb_lewm` to find the K nearest observed states — useful for reward shaping. +4. **Offline RL data mixing**: union multiple datasets in one table with a `dataset_name` column, filter at training time with no file management. +5. **Reward relabeling**: append a `reward` column after collection without rewriting pixel data. +6. **Active data collection**: stream new rollout episodes into the table while training runs — LanceDB concurrent writes are safe. +7. **Embedding visualization**: dump `emb_lewm` → UMAP to inspect the latent space structure the world model has learned. diff --git a/examples/leWorldModel/bench.py b/examples/leWorldModel/bench.py new file mode 100644 index 0000000..23e9cb5 --- /dev/null +++ b/examples/leWorldModel/bench.py @@ -0,0 +1,353 @@ +""" +leWorldModel dataloader throughput benchmark: LanceDB vs HDF5. + +Measures how fast each backend can feed batches to the GPU, independently +of training compute. Three backends: + + LanceDB S3/local — our implementation, parallel workers, no download step + HDF5 local — reads from a local file (best-case for HDF5) + HDF5 s3fs — reads directly from S3 via s3fs (realistic, no download) + +Usage: + # LanceDB S3 vs HDF5 local + python bench.py \\ + --lance-uri s3://my-bucket/lewm \\ + --table-name lewm_pusht \\ + --hdf5-local /dev/shm/pusht.hdf5 + + # All three backends + python bench.py \\ + --lance-uri s3://my-bucket/lewm \\ + --table-name lewm_pusht \\ + --hdf5-local /dev/shm/pusht.hdf5 \\ + --hdf5-s3-key hdf5/pusht.hdf5 \\ + --s3-bucket my-bucket + + # Credentials via environment variables (AWS_ACCESS_KEY_ID etc.) + python bench.py --lance-uri s3://my-bucket/lewm --table-name lewm_pusht +""" + +import argparse +import multiprocessing +import os +import time + +import h5py +import hdf5plugin # noqa: F401 — registers HDF5 decompression filters +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader +from torchvision import transforms + +import sys +sys.path.insert(0, os.path.dirname(__file__)) +from lewm_lance import make_lewm_lance_loader + + +# --------------------------------------------------------------------------- +# Defaults +# --------------------------------------------------------------------------- + +BATCH_SIZE = 128 +NUM_STEPS = 4 +IMAGE_SIZE = 224 +NUM_WORKERS = 8 +PREFETCH_FACTOR = 3 +WARMUP_BATCHES = 5 +BENCH_BATCHES = 50 +COLUMNS = ["pixels", "action", "proprio", "state"] + + +# --------------------------------------------------------------------------- +# HDF5 dataset (mirrors the original stable_worldmodel.data.HDF5Dataset) +# --------------------------------------------------------------------------- + +_TRANSFORM = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), +]) + + +class HDF5LeWMDataset(torch.utils.data.Dataset): + """ + HDF5-backed temporal-window dataset matching stable-worldmodel's HDF5Dataset. + + The HDF5 schema uses per-episode metadata arrays: + ep_len — shape (n_episodes,) episode lengths + ep_offset — shape (n_episodes,) global start row per episode + + Valid clip_indices are (episode_idx, local_start) pairs where a full window + of span = num_steps * frameskip rows fits within the episode. At read time, + the global slice [offset + local_start : offset + local_start + span] is + fetched and every frameskip-th frame is selected. + + Pixels are stored as (N, H, W, C) uint8 — no transpose needed before PIL. + Non-pixel columns (action, proprio, etc.) are cached in RAM at init time; + only pixels are read from the file at __getitem__ time. + + hdf5_src can be a local file path (str) or an s3fs file object. + h5py is opened lazily per worker because handles are not fork-safe. + """ + + def __init__(self, hdf5_src, columns, num_steps=NUM_STEPS, frameskip=1): + self._src = hdf5_src + self.columns = columns + self.num_steps = num_steps + self.frameskip = frameskip + self._span = num_steps * frameskip + self._file = None + + with h5py.File(self._src, "r", rdcc_nbytes=256 * 1024 * 1024) as f: + ep_len = np.array(f["ep_len"], dtype=np.int32) + ep_offset = np.array(f["ep_offset"], dtype=np.int32) + # Cache all non-pixel columns in RAM — avoids repeated random HDF5 seeks + self._cached: dict[str, np.ndarray] = {} + for col in columns: + if col != "pixels": + self._cached[col] = np.array(f[col], dtype=np.float32) + + # Build (ep_idx, local_start) pairs for all valid windows + self._clip_indices: list[tuple[int, int]] = [] + for ep_idx, (off, length) in enumerate(zip(ep_offset.tolist(), ep_len.tolist())): + if length < self._span: + continue + for local_start in range(length - self._span + 1): + self._clip_indices.append((ep_idx, local_start)) + + self._ep_offset = ep_offset + + def __len__(self): + return len(self._clip_indices) + + def __getstate__(self): + state = self.__dict__.copy() + state["_file"] = None # h5py handle can't be pickled + return state + + def _ensure_open(self): + if self._file is None: + self._file = h5py.File(self._src, "r", swmr=True, rdcc_nbytes=256 * 1024 * 1024) + + def __getitem__(self, clip_idx): + self._ensure_open() + ep_idx, local_start = self._clip_indices[clip_idx] + g_start = int(self._ep_offset[ep_idx]) + local_start + g_end = g_start + self._span + + pixels_raw = self._file["pixels"][g_start:g_end:self.frameskip] # (T, H, W, C) + frames = [ + _TRANSFORM(Image.fromarray(pixels_raw[t].astype(np.uint8))) + for t in range(self.num_steps) + ] + + sample = {"pixels": torch.stack(frames)} + for col in self.columns: + if col == "pixels": + continue + data = self._cached[col][g_start:g_end:self.frameskip] + sample[col] = torch.from_numpy(np.nan_to_num(data, nan=0.0)) + return sample + + +def _collate(samples): + return {k: torch.stack([s[k] for s in samples], dim=0) for k in samples[0]} + + +def make_hdf5_loader(hdf5_src, columns, batch_size, num_workers, prefetch_factor): + return DataLoader( + HDF5LeWMDataset(hdf5_src, columns), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + collate_fn=_collate, + persistent_workers=(num_workers > 0), + prefetch_factor=prefetch_factor if num_workers > 0 else None, + multiprocessing_context="spawn" if num_workers > 0 else None, + ) + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + +def measure_throughput(loader, label, warmup, steps): + """ + Iterate the loader for `warmup` batches (discarded), then time `steps` batches. + Returns samples/sec and average batch latency in ms. + """ + print(f"\n{'─' * 60}") + print(f" {label}") + print(f" batch_size={loader.batch_size} workers={loader.num_workers}") + print(f"{'─' * 60}") + + it = iter(loader) + + print(f" warming up ({warmup} batches)...") + for _ in range(warmup): + batch = next(it, None) + if batch is None: + it = iter(loader) + batch = next(it) + # Touch the pixels tensor to ensure decoding actually happened + _ = batch["pixels"].shape + + print(f" benchmarking ({steps} batches)...") + batch_times = [] + t_total = time.perf_counter() + + for _ in range(steps): + t0 = time.perf_counter() + batch = next(it, None) + if batch is None: + it = iter(loader) + batch = next(it) + _ = batch["pixels"].shape + batch_times.append(time.perf_counter() - t0) + + elapsed = time.perf_counter() - t_total + samples_sec = (steps * loader.batch_size) / elapsed + avg_ms = np.mean(batch_times) * 1000 + p99_ms = np.percentile(batch_times, 99) * 1000 + + print(f"\n samples/sec : {samples_sec:,.0f}") + print(f" batch avg : {avg_ms:.1f} ms") + print(f" batch p99 : {p99_ms:.1f} ms") + + return {"label": label, "samples_sec": samples_sec, "avg_ms": avg_ms, "p99_ms": p99_ms} + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _build_parser(): + p = argparse.ArgumentParser( + description="Benchmark LanceDB vs HDF5 dataloader throughput for leWorldModel", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument("--lance-uri", required=True) + p.add_argument("--table-name", required=True) + p.add_argument("--hdf5-local", default=None, help="Path to local HDF5 file") + p.add_argument("--hdf5-s3-key", default=None, help="S3 object key for HDF5 file") + p.add_argument("--s3-bucket", default=None, help="S3 bucket (for --hdf5-s3-key)") + p.add_argument("--columns", nargs="+", default=COLUMNS) + p.add_argument("--batch-size", type=int, default=BATCH_SIZE) + p.add_argument("--num-workers", type=int, default=NUM_WORKERS) + p.add_argument("--warmup", type=int, default=WARMUP_BATCHES) + p.add_argument("--steps", type=int, default=BENCH_BATCHES) + + s3 = p.add_argument_group("S3 credentials (fall back to AWS_* env vars)") + s3.add_argument("--aws-access-key-id", default=os.environ.get("AWS_ACCESS_KEY_ID")) + s3.add_argument("--aws-secret-access-key", default=os.environ.get("AWS_SECRET_ACCESS_KEY")) + s3.add_argument("--aws-session-token", default=os.environ.get("AWS_SESSION_TOKEN")) + s3.add_argument("--aws-region", default=os.environ.get("AWS_DEFAULT_REGION", "us-east-1")) + s3.add_argument("--s3-endpoint", default=os.environ.get("AWS_ENDPOINT_URL")) + return p + + +def main(): + args = _build_parser().parse_args() + + print(f"\nleWorldModel dataloader benchmark") + print(f" batch_size : {args.batch_size}") + print(f" num_workers : {args.num_workers}") + print(f" T (frames) : {NUM_STEPS}") + print(f" warmup : {args.warmup} batches bench: {args.steps} batches") + + # Build S3 storage_options for LanceDB + storage_options = {} + if args.aws_access_key_id: + storage_options["aws_access_key_id"] = args.aws_access_key_id + if args.aws_secret_access_key: + storage_options["aws_secret_access_key"] = args.aws_secret_access_key + if args.aws_session_token: + storage_options["aws_session_token"] = args.aws_session_token + if args.aws_region: + storage_options["region"] = args.aws_region + if args.s3_endpoint: + storage_options["endpoint_url"] = args.s3_endpoint + storage_options["aws_virtual_hosted_style_request"] = "false" + + connect_kwargs = {"storage_options": storage_options} if storage_options else {} + + results = [] + + # 1. LanceDB + lance_loader = make_lewm_lance_loader( + uri=args.lance_uri, + table_name=args.table_name, + columns=args.columns, + batch_size=args.batch_size, + num_steps=NUM_STEPS, + img_size=IMAGE_SIZE, + num_workers=args.num_workers, + prefetch_factor=PREFETCH_FACTOR, + **connect_kwargs, + ) + backend = "S3" if args.lance_uri.startswith("s3://") else "local" + results.append(measure_throughput( + lance_loader, + f"LanceDB {backend} ({args.table_name})", + args.warmup, args.steps, + )) + + # 2. HDF5 local + if args.hdf5_local: + hdf5_local_loader = make_hdf5_loader( + args.hdf5_local, args.columns, args.batch_size, args.num_workers, PREFETCH_FACTOR, + ) + results.append(measure_throughput( + hdf5_local_loader, + f"HDF5 local ({os.path.basename(args.hdf5_local)})", + args.warmup, args.steps, + )) + + # 3. HDF5 via s3fs (reads directly from S3, no local copy) + if args.hdf5_s3_key and args.s3_bucket: + import s3fs + s3_kwargs = {} + if args.aws_access_key_id: + s3_kwargs["key"] = args.aws_access_key_id + if args.aws_secret_access_key: + s3_kwargs["secret"] = args.aws_secret_access_key + if args.aws_session_token: + s3_kwargs["token"] = args.aws_session_token + client_kwargs = {} + if args.aws_region: + client_kwargs["region_name"] = args.aws_region + if args.s3_endpoint: + client_kwargs["endpoint_url"] = args.s3_endpoint + if client_kwargs: + s3_kwargs["client_kwargs"] = client_kwargs + + fs = s3fs.S3FileSystem(**s3_kwargs) + s3_file = fs.open(f"{args.s3_bucket}/{args.hdf5_s3_key}", "rb") + + hdf5_s3_loader = make_hdf5_loader( + s3_file, args.columns, args.batch_size, args.num_workers, PREFETCH_FACTOR, + ) + results.append(measure_throughput( + hdf5_s3_loader, + f"HDF5 s3fs (s3://{args.s3_bucket}/{args.hdf5_s3_key})", + args.warmup, args.steps, + )) + + # Summary table + if len(results) > 1: + baseline = results[-1]["samples_sec"] + print(f"\n{'=' * 60}") + print(f" {'Backend':<46} {'samples/sec':>12} {'avg ms':>8} {'speedup':>8}") + print(f"{'─' * 60}") + for r in sorted(results, key=lambda x: -x["samples_sec"]): + speedup = r["samples_sec"] / baseline + print(f" {r['label']:<46} {r['samples_sec']:>12,.0f} {r['avg_ms']:>7.1f} {speedup:>7.1f}×") + print(f"{'=' * 60}") + + +if __name__ == "__main__": + multiprocessing.set_start_method("spawn", force=True) + main() diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml new file mode 100644 index 0000000..b5640e3 --- /dev/null +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -0,0 +1,42 @@ +# leWorldModel training config — PushT dataset +# Mirrors le-wm/config/train/lewm.yaml + data/pusht.yaml + +model: + encoder_name: "vit_tiny_patch14_224" + image_size: 224 + embed_dim: 192 + history_size: 3 + num_preds: 1 + predictor_depth: 6 + predictor_heads: 16 + predictor_mlp_dim: 2048 + proj_hidden: 2048 + # action_dim is inferred automatically from the first training batch + +data: + lance_uri: "./lewm_lance" # override with s3://bucket/prefix for cloud + table_name: "lewm_pusht" + columns: ["pixels", "action", "proprio", "state"] + num_workers: 6 + prefetch_factor: 3 + val_fraction: 0.1 + seed: 42 + +training: + batch_size: 128 + max_epochs: 100 + lr: 5.0e-5 + weight_decay: 1.0e-3 + warmup_epochs: 10 + gradient_clip: 1.0 + save_every_n_epochs: 10 + checkpoint_dir: "./checkpoints" + precision: "bf16-mixed" # bf16-mixed | 16-mixed | 32 + log_every_n_steps: 50 + +loss: + sigreg_weight: 0.09 + sigreg_knots: 17 + sigreg_num_proj: 1024 + +wandb_project: "lewm-lancedb" diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py new file mode 100644 index 0000000..11732ad --- /dev/null +++ b/examples/leWorldModel/create_data.py @@ -0,0 +1,537 @@ +""" +Convert leWorldModel HDF5 datasets to LanceDB tables. + +Each LanceDB row = one timestep (same granularity as the source HDF5). +See dataset.md for full format documentation. + +Supported datasets: + reacher → ~/.stable-wm/reacher.hdf5 + cube → ~/.stable-wm/cube_single_expert.hdf5 + pusht → ~/.stable-wm/pusht_expert_train.hdf5 + tworoom → ~/.stable-wm/tworoom.hdf5 + +Usage: + # Convert all datasets to a local LanceDB store + python create_data.py --dataset all --lance-uri ./lewm_lance + + # Convert a single dataset to S3-backed LanceDB + python create_data.py --dataset pusht --lance-uri s3://my-bucket/lewm + + # Overwrite an existing table + python create_data.py --dataset pusht --overwrite + + # Convert + back-fill embeddings for vector search + python create_data.py --dataset pusht --embed --embedding-model dinov2 +""" + +import argparse +import io +import os +from collections.abc import Iterator + +import h5py +import hdf5plugin # noqa: F401 — registers HDF5 decompression filters (Blosc, Zstd, etc.) +import lancedb +import numpy as np +import pyarrow as pa +from PIL import Image +from tqdm import tqdm + + +# --------------------------------------------------------------------------- +# Dataset registry +# --------------------------------------------------------------------------- + +DATASETS = { + "reacher": { + "hdf5_file": "reacher.hdf5", + "table_name": "lewm_reacher", + "columns": ["pixels", "action", "observation"], + }, + "cube": { + "hdf5_file": "cube_single_expert.hdf5", + "table_name": "lewm_cube", + "columns": ["pixels", "action", "observation"], + }, + "pusht": { + "hdf5_file": "pusht_expert_train.hdf5", + "table_name": "lewm_pusht", + "columns": ["pixels", "action", "proprio", "state"], + }, + "tworoom": { + "hdf5_file": "tworoom.hdf5", + "table_name": "lewm_tworoom", + "columns": ["pixels", "action", "proprio"], + }, +} + +JPEG_QUALITY = 95 # 95 → ~13× smaller than raw uint8, negligible quality loss +BATCH_ROWS = 1000 # rows per RecordBatch yielded to LanceDB + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _to_jpeg_bytes(frame: np.ndarray) -> bytes: + """(C,H,W) or (H,W,C) uint8 ndarray → JPEG-compressed bytes.""" + if frame.ndim == 3 and frame.shape[0] in (1, 3, 4): + frame = np.transpose(frame, (1, 2, 0)) # (C,H,W) → (H,W,C) + buf = io.BytesIO() + Image.fromarray(frame.astype(np.uint8)).save(buf, format="JPEG", quality=JPEG_QUALITY) + return buf.getvalue() + + +def _infer_episode_step(f: h5py.File, total: int) -> tuple[np.ndarray, np.ndarray]: + """ + Return (episode_idx_arr, step_idx_arr) as int32 arrays from an open HDF5 file. + + The stable-worldmodel HDF5 format stores per-episode metadata: + ep_len — int array of shape (n_episodes,) with the length of each episode + ep_offset — int array of shape (n_episodes,) with the global start row of each episode + + These are expanded into per-row arrays for the LanceDB schema. + """ + ep_len = np.array(f["ep_len"], dtype=np.int32) + ep_offset = np.array(f["ep_offset"], dtype=np.int32) + + episode_idx = np.zeros(total, dtype=np.int32) + step_idx = np.zeros(total, dtype=np.int32) + for i, (off, length) in enumerate(zip(ep_offset.tolist(), ep_len.tolist())): + episode_idx[off : off + length] = i + step_idx[off : off + length] = np.arange(length, dtype=np.int32) + + return episode_idx, step_idx + + +def _build_schema(columns: list[str], dims: dict[str, int]) -> pa.Schema: + """Build a PyArrow schema for a leWorldModel table.""" + fields = [ + pa.field("episode_idx", pa.int32()), + pa.field("step_idx", pa.int32()), + pa.field("pixels", pa.binary()), + pa.field("pixels_h", pa.int16()), + pa.field("pixels_w", pa.int16()), + ] + for col in columns: + if col == "pixels": + continue + fields.append(pa.field(col, pa.list_(pa.float32(), dims[col]))) + return pa.schema(fields) + + +def _record_batch_reader( + f: h5py.File, + columns: list[str], + episode_arr: np.ndarray, + step_arr: np.ndarray, + h: int, + w: int, + schema: pa.Schema, +) -> pa.RecordBatchReader: + """ + Return a pa.RecordBatchReader that streams the HDF5 file in BATCH_ROWS chunks. + + Using a RecordBatchReader rather than repeated table.add() calls lets + LanceDB write all data in a single pass through the file without accumulating + large in-memory lists, and avoids creating many small Lance fragments. + """ + total = len(episode_arr) + non_pixel_cols = [c for c in columns if c != "pixels"] + + def _generate() -> Iterator[pa.RecordBatch]: + # Buffers — reset every BATCH_ROWS rows + ep_buf: list[int] = [] + st_buf: list[int] = [] + px_buf: list[bytes] = [] + ph_buf: list[int] = [] + pw_buf: list[int] = [] + col_bufs: dict[str, list[list[float]]] = {c: [] for c in non_pixel_cols} + + for idx in tqdm(range(total), desc=" Converting", unit="step"): + ep_buf.append(int(episode_arr[idx])) + st_buf.append(int(step_arr[idx])) + px_buf.append(_to_jpeg_bytes(np.array(f["pixels"][idx]))) + ph_buf.append(h) + pw_buf.append(w) + for col in non_pixel_cols: + col_bufs[col].append(np.array(f[col][idx], dtype=np.float32).flatten().tolist()) + + if len(ep_buf) == BATCH_ROWS: + yield _make_batch(ep_buf, st_buf, px_buf, ph_buf, pw_buf, col_bufs, schema) + ep_buf, st_buf, px_buf, ph_buf, pw_buf = [], [], [], [], [] + col_bufs = {c: [] for c in non_pixel_cols} + + if ep_buf: + yield _make_batch(ep_buf, st_buf, px_buf, ph_buf, pw_buf, col_bufs, schema) + + return pa.RecordBatchReader.from_batches(schema, _generate()) + + +def _make_batch( + ep_buf: list[int], + st_buf: list[int], + px_buf: list[bytes], + ph_buf: list[int], + pw_buf: list[int], + col_bufs: dict[str, list[list[float]]], + schema: pa.Schema, +) -> pa.RecordBatch: + arrays = [ + pa.array(ep_buf, type=pa.int32()), + pa.array(st_buf, type=pa.int32()), + pa.array(px_buf, type=pa.binary()), + pa.array(ph_buf, type=pa.int16()), + pa.array(pw_buf, type=pa.int16()), + ] + for col in col_bufs: + field_type = schema.field(col).type # fixed_size_list[D] + arrays.append(pa.array(col_bufs[col], type=field_type)) + return pa.RecordBatch.from_arrays(arrays, schema=schema) + + +# --------------------------------------------------------------------------- +# Core conversion +# --------------------------------------------------------------------------- + +def convert_dataset( + dataset_name: str, + hdf5_dir: str, + lance_uri: str, + overwrite: bool = False, + connect_kwargs: dict | None = None, +): + cfg = DATASETS[dataset_name] + hdf5_path = os.path.join(hdf5_dir, cfg["hdf5_file"]) + table_name = cfg["table_name"] + columns = cfg["columns"] + connect_kwargs = connect_kwargs or {} + + if not os.path.exists(hdf5_path): + raise FileNotFoundError( + f"HDF5 not found: {hdf5_path}\n" + "Download with: " + f"python -c \"import stable_worldmodel as swm; swm.data.download('{dataset_name}')\"" + ) + + print(f"\n{'=' * 60}") + print(f"Dataset : {dataset_name}") + print(f"HDF5 : {hdf5_path}") + print(f"Lance : {lance_uri} (table={table_name})") + print(f"{'=' * 60}") + + db = lancedb.connect(lance_uri, **connect_kwargs) + + if table_name in db.table_names(): + if overwrite: + print(f" Dropping existing table '{table_name}'...") + db.drop_table(table_name) + else: + print(f" Table '{table_name}' already exists. Use --overwrite to recreate.") + return + + with h5py.File(hdf5_path, "r") as f: + total = len(f["pixels"]) + episode_arr, step_arr = _infer_episode_step(f, total) + n_episodes = int(episode_arr.max()) + 1 + + print(f" Steps : {total:,}") + print(f" Episodes : {n_episodes:,}") + + # Determine vector dims from first row + dims: dict[str, int] = {} + for col in columns: + if col == "pixels": + continue + dims[col] = int(np.array(f[col][0], dtype=np.float32).flatten().shape[0]) + print(f" {col:<14}: dim={dims[col]}") + + # Determine pixel dimensions + sample_frame = np.array(f["pixels"][0]) + if sample_frame.ndim == 3 and sample_frame.shape[0] in (1, 3, 4): + _, h, w = sample_frame.shape + else: + h, w = sample_frame.shape[:2] + print(f" pixels : ({h} × {w}) → JPEG quality={JPEG_QUALITY}") + + schema = _build_schema(columns, dims) + reader = _record_batch_reader(f, columns, episode_arr, step_arr, h, w, schema) + + # Single create_table call — LanceDB reads from the reader in streaming fashion. + # This produces one Lance fragment per BATCH_ROWS rows, then compacts. + db.create_table(table_name, data=reader, schema=schema) + + final_count = len(db.open_table(table_name)) + print(f" Done! {final_count:,} rows written to '{table_name}'.") + assert final_count == total, f"Row count mismatch: wrote {final_count}, expected {total}" + + +# --------------------------------------------------------------------------- +# Embedding back-fill via LanceDB Geneva (Enterprise feature engineering) +# --------------------------------------------------------------------------- +# +# WHEN to generate embeddings and WHICH model to use: +# +# PRE-TRAINING (before training LeWM — for EDA, clustering, quality filtering): +# Use a frozen foundation model: DINOv2 or CLIP. +# The encoder requires no training — embeddings are semantically meaningful +# from day one. Good for: +# - Clustering frames to discover sub-behaviours +# - Detecting near-duplicate / degenerate episodes before wasting GPU time +# - Curriculum design: embed → cluster → order episodes by difficulty +# - Goal-state retrieval using natural-language queries (CLIP only) +# +# POST-TRAINING (after training LeWM — for analysis of the learned world model): +# Use the trained LeWM encoder (CLS token of the ViT). The latent space now +# reflects the dynamics your model has learned, not just visual similarity. +# Good for: +# - Validating that the encoder separates distinct behaviours +# - ANN retrieval of states that "look the same to the world model" +# - Debugging: find states the model consistently mispredicts +# +# Using LeWM embeddings BEFORE training gives meaningless results — +# the encoder is randomly initialised. +# +# HOW: LanceDB Geneva (LanceDB Enterprise) +# Geneva's UDF API replaces the manual encode-loop-then-merge pattern with: +# 1. Define a stateful GPU UDF (@udf class with setup() + __call__()) +# 2. Register it as a column: tbl.add_columns({"emb_X": MyUDF()}) +# 3. Backfill: tbl.backfill("emb_X", batch_size=32) +# or async with progress: tbl.backfill_async("emb_X", concurrency=4) +# Geneva handles batching, GPU process concurrency, partial commits, and +# incremental re-runs (where="emb_X is null") automatically. +# +# --------------------------------------------------------------------------- + +def add_embeddings_geneva( + lance_uri: str, + table_name: str, + model_name: str = "dinov2", + checkpoint: str | None = None, + batch_size: int = 32, + img_size: int = 224, + concurrency: int = 1, + connect_kwargs: dict | None = None, +): + """ + Add a frame embedding column to a LanceDB table using Geneva UDFs. + + Requires LanceDB Enterprise with the `geneva` package installed. + Geneva handles batching, GPU concurrency, partial commits, and incremental + re-runs — no manual encode loop needed. + + Args: + model_name: "dinov2" | "clip" | "lewm" + dinov2/clip → pre-training EDA (frozen foundation model) + lewm → post-training analysis (requires checkpoint) + checkpoint: Path to a trained LeWM .ckpt file (only for model_name="lewm"). + batch_size: Frames per UDF call (tune to fit GPU VRAM). + concurrency: Number of parallel GPU worker processes for backfill. + """ + import geneva + import pyarrow as pa + + connect_kwargs = connect_kwargs or {} + conn = geneva.connect(lance_uri, **connect_kwargs) + tbl = conn.open_table(table_name) + + col_name = f"emb_{model_name}" + if col_name in tbl.schema.names: + print(f" '{col_name}' column already present. Skipping.") + return + + # Build the UDF class for the chosen model + udf_cls = _make_embedding_udf(model_name, checkpoint, img_size) + + print(f" Registering '{col_name}' UDF ({model_name})...") + tbl.add_columns({col_name: udf_cls()}) + + print(f" Starting backfill (concurrency={concurrency}, batch_size={batch_size})...") + fut = tbl.backfill( + col_name, + batch_size=batch_size, + concurrency=concurrency, + # Only process rows that don't have embeddings yet — safe to re-run + where=f"{col_name} IS NULL", + ) + + print(f" Backfill complete. Building IVF-PQ vector index on '{col_name}'...") + tbl.create_index( + column=col_name, + index_type="IVF_PQ", + num_partitions=64, + num_sub_vectors=16, + ) + print(f" Done! ANN search: tbl.search(vec, vector_column_name='{col_name}').limit(10)") + + +def _make_embedding_udf(model_name: str, checkpoint: str | None, img_size: int): + """ + Return a Geneva UDF class that encodes JPEG-binary frames with the chosen model. + + Geneva UDFs are stateful classes: + setup() — called once per worker process to load model weights + __call__() — called per row (or per batch if batch_size > 1) + + The UDF is decorated with @geneva.udf(data_type=...) to declare the output + Arrow type so Geneva can build the schema before any data is processed. + """ + import geneva + import io as _io + import numpy as _np + from PIL import Image as _Image + from torchvision import transforms as _transforms + + _transform = _transforms.Compose([ + _transforms.Resize((img_size, img_size)), + _transforms.ToTensor(), + _transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + if model_name == "dinov2": + EMBED_DIM = 384 + + @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM)) + class DINOv2Embedder: + def setup(self): + import timm, torch + self.model = timm.create_model( + "vit_small_patch14_dinov2.lvd142m", pretrained=True, num_classes=0 + ).cuda().eval() + self.torch = torch + + def __call__(self, pixel_bytes: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + t = _transform(img).unsqueeze(0).cuda() + with self.torch.no_grad(): + return self.model(t)[0].cpu().tolist() + + return DINOv2Embedder + + if model_name == "clip": + EMBED_DIM = 512 + + @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM)) + class CLIPEmbedder: + def setup(self): + import clip, torch + self.model, self.preprocess = clip.load("ViT-B/32", device="cuda") + self.model.eval() + self.torch = torch + + def __call__(self, pixel_bytes: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + t = self.preprocess(img).unsqueeze(0).cuda() + with self.torch.no_grad(): + return self.model.encode_image(t)[0].cpu().float().tolist() + + return CLIPEmbedder + + if model_name == "lewm": + assert checkpoint, "--checkpoint is required for --embedding-model lewm" + _ckpt = checkpoint + + @geneva.udf(data_type=pa.list_(pa.float32(), 192)) # ViT-tiny embed_dim + class LeWMEmbedder: + def setup(self): + import torch + model = torch.load(_ckpt, map_location="cuda") + model.eval() + self.encoder = model.encoder + self.torch = torch + + def __call__(self, pixel_bytes: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + t = _transform(img).unsqueeze(0).cuda() + with self.torch.no_grad(): + out = self.encoder(t) + return out[0, 0, :].cpu().tolist() # CLS token + + return LeWMEmbedder + + raise ValueError(f"Unknown model_name: {model_name!r}. Choose dinov2 | clip | lewm") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="Convert leWorldModel HDF5 datasets to LanceDB tables" + ) + parser.add_argument( + "--dataset", + choices=list(DATASETS.keys()) + ["all"], + default="all", + help="Dataset to convert (default: all)", + ) + parser.add_argument( + "--hdf5-dir", + default=os.path.expanduser("~/.stable-wm"), + help="Directory containing .hdf5 files (default: ~/.stable-wm)", + ) + parser.add_argument( + "--lance-uri", + default="./lewm_lance", + help="LanceDB URI — local path or s3://bucket/prefix (default: ./lewm_lance)", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Drop and recreate existing tables", + ) + parser.add_argument( + "--embed", + action="store_true", + help="Add an embedding column for vector search after conversion", + ) + parser.add_argument( + "--embedding-model", + choices=["clip", "dinov2", "lewm"], + default="dinov2", + help=( + "Which vision model to use for embeddings.\n" + " dinov2 — Meta DINOv2 ViT-S/14, pre-trained (best for pre-training EDA)\n" + " clip — OpenAI CLIP ViT-B/32, pre-trained (supports text queries)\n" + " lewm — Trained LeWM encoder (post-training analysis only, needs --checkpoint)" + ), + ) + parser.add_argument( + "--checkpoint", + metavar="CKPT", + help="Path to a trained LeWM object checkpoint (required when --embedding-model=lewm)", + ) + parser.add_argument( + "--embed-batch-size", + type=int, + default=256, + help="Frames per GPU forward pass during embedding generation (default: 256)", + ) + args = parser.parse_args() + + datasets_to_run = list(DATASETS.keys()) if args.dataset == "all" else [args.dataset] + + for ds_name in datasets_to_run: + convert_dataset( + dataset_name=ds_name, + hdf5_dir=args.hdf5_dir, + lance_uri=args.lance_uri, + overwrite=args.overwrite, + ) + + if args.embed: + for ds_name in datasets_to_run: + add_embeddings_geneva( + lance_uri=args.lance_uri, + table_name=DATASETS[ds_name]["table_name"], + model_name=args.embedding_model, + checkpoint=args.checkpoint, + batch_size=args.embed_batch_size, + ) + + print("\nAll done!") + + +if __name__ == "__main__": + main() diff --git a/examples/leWorldModel/dataset.md b/examples/leWorldModel/dataset.md new file mode 100644 index 0000000..ed0d737 --- /dev/null +++ b/examples/leWorldModel/dataset.md @@ -0,0 +1,189 @@ +# Dataset Format: HDF5 vs LanceDB + +This document describes the original leWorldModel HDF5 dataset format, what we +store in LanceDB, and why the two differ in specific ways. + +--- + +## Original HDF5 Format + +leWorldModel datasets are produced by the `stable-worldmodel` package and +downloaded into `$STABLEWM_HOME` (default: `~/.stable-wm/`). Each dataset is +a **single monolithic HDF5 file** (e.g. `pusht_expert_train.hdf5`). + +### Structure + +HDF5 stores data as **flat arrays at the file root** — one row per timestep, +all episodes concatenated sequentially. There is no nesting. + +``` +pusht_expert_train.hdf5 +├── pixels (N, C, H, W) uint8 — raw pixel frames +├── action (N, A) float32 — continuous action vectors +├── proprio (N, P) float32 — proprioceptive state +├── state (N, S) float32 — full simulator state +├── episode_idx (N,) int32 — which episode each row belongs to +└── step_idx (N,) int32 — step counter within episode +``` + +`N` = total number of timesteps across all episodes. Episodes are contiguous: +all rows for episode 0 come first, then episode 1, etc. + +`episode_idx` and `step_idx` are index columns, not observation data — they +exist purely to let you reconstruct episode boundaries. + +### Per-dataset column map + +| Dataset file | pixels shape | Columns present | +|-------------------------------|-------------------|-------------------------------------------| +| `reacher.hdf5` | (N, 3, H, W) | pixels, action, observation | +| `cube_single_expert.hdf5` | (N, 3, H, W) | pixels, action, observation | +| `pusht_expert_train.hdf5` | (N, 3, H, W) | pixels, action, proprio, state | +| `tworoom.hdf5` | (N, 3, H, W) | pixels, action, proprio | + +`H` and `W` vary by dataset but are typically 64 or 84 pixels in the raw files; +`train.py` resizes to 224×224 at load time. + +### Episode boundaries + +The `stable_worldmodel.data.HDF5Dataset` class reconstructs episode windows by +scanning `episode_idx` and building a mapping from `(episode, local_step)` to +the global row index `i`. Training samples T consecutive rows within the same +episode: + +``` +global_row = episode_start[ep] + local_step +window = [global_row, global_row+1, ..., global_row+T-1] +``` + +--- + +## LanceDB Format + +### What changes — and what stays the same + +The row granularity is **identical**: one LanceDB row = one HDF5 row = one +timestep. We do not reshape, aggregate, or split the data differently. + +What does change: + +| Aspect | HDF5 | LanceDB | +|--------|------|---------| +| **Storage layout** | Monolithic `.h5` file | Columnar Lance fragments (local or S3) | +| **Pixel encoding** | Raw uint8 (C, H, W) | JPEG-compressed binary (quality=95) | +| **Vector columns** | float32 ndarrays | `fixed_size_list` in Arrow schema | +| **Index columns** | `episode_idx`, `step_idx` arrays | Same, stored as int32 Arrow columns | +| **Reads** | `h5py.File[col][i]` — single-threaded | `Permutation.__getitems__` — parallel, multi-worker safe | +| **Partial column reads** | Requires loading the full compound dataset | Native: `.select_columns(["action"])` reads only that column | + +### Schema (example: pusht) + +``` +episode_idx int32 +step_idx int32 +pixels binary ← JPEG bytes (not raw uint8) +pixels_h int16 ← stored so decode knows H +pixels_w int16 ← stored so decode knows W +action fixed_size_list[A] +proprio fixed_size_list[P] +state fixed_size_list[S] +``` + +After running `create_data.py --embed`, embedding columns are appended: + +``` +emb_dinov2 fixed_size_list[384] ← DINOv2 ViT-S/14 (pre-training EDA) +emb_clip fixed_size_list[512] ← CLIP ViT-B/32 (text queries) +emb_lewm fixed_size_list[192] ← Trained LeWM CLS token (post-training) +``` + +None of these columns are in the original HDF5 files. + +### Why JPEG for pixels, and does it hurt training speed? + +**The storage case** is clear: raw uint8 RGB at 224×224 is 150 KB per frame. +At JPEG quality=95 the same frame compresses to ~10-15 KB — a 10-15× reduction +with negligible perceptual quality loss for vision model training. + +**The compute tradeoff** is real but net-positive when combined with +LanceDB's multi-worker DataLoader: + +| | Raw uint8 in HDF5 | JPEG binary in LanceDB | +|--|--|--| +| Bytes transferred from storage per frame | 150 KB (raw float32: 600 KB) | 12 KB | +| I/O time (NVMe, 3 GB/s) per 1000 frames | ~50 ms | ~4 ms | +| JPEG decode time per 1000 frames (CPU) | — | ~30 ms (libturbo) | +| Decode happens in | main process | DataLoader worker (parallel) | + +The decode cost (~30 ms/1000 frames on a single CPU core) is hidden because: + +1. **Parallel decoding across workers.** With `num_workers=6` each worker + decodes its own subset of JPEG frames. The GPU training step and worker + decoding overlap in time — decoding is not on the critical path. + +2. **I/O dominates for large datasets, especially on S3.** When data lives on + S3 (which is the point of using LanceDB), the network round-trip to fetch + 150 KB vs 12 KB per frame is 12× slower. JPEG decode is cheap compared to + cross-AZ network I/O. + +3. **HDF5 single-threaded read negates its "no decode" advantage.** HDF5 reads + are serialized through a file lock. Even though there's no decode step, all + 8 DataLoader workers queue behind the same lock. In practice the leWorldModel + HDF5 pipeline runs roughly 2-3 workers effectively regardless of how many you + spawn. LanceDB workers never contend with each other. + +**When raw uint8 would be better:** if your entire dataset fits in RAM and you +use a RAM disk, raw storage avoids the decode cost. For datasets that fit in +`/dev/shm` (tens of GB) this is a valid strategy. LanceDB supports this too — +you can store raw `pa.binary()` frames and handle encode/decode yourself. + +**Bottom line:** for typical leWorldModel dataset sizes (gigabytes to tens of +gigabytes), training on JPEG-compressed LanceDB tables matches or exceeds the +throughput of HDF5, because the 8× I/O reduction and parallelism gains exceed +the added decode cost. The ViT MFU benchmarks in `examples/ViT/` confirm this: +LanceDB with JPEG achieves 37-39% MFU on H200 vs 13% for raw S3 object storage. + +### Why `fixed_size_list` instead of a flat float column? + +Arrow `fixed_size_list[D]` lets you: +- Read an entire column as a 2D numpy array in one zero-copy call +- Use it as a vector column for ANN indexing after adding embeddings +- Keep schema self-describing (dimension D is part of the type) + +A plain `list` (variable-length) would also work but is slower to +convert to numpy because Arrow cannot guarantee contiguous memory layout. + +--- + +## Episode Boundary Handling + +Both HDF5 and LanceDB store episodes as contiguous row ranges. + +`LeWMLanceDataset` reconstructs valid window positions at init time by loading +`(episode_idx, step_idx)` into two numpy arrays (~16 bytes/row, trivial even at +1M steps) and checking: + +```python +same_ep = episode_idx[i+offset] == episode_idx[i] # same episode +consec = step_idx[i+offset] == step_idx[i] + offset # consecutive steps +valid[i] = all(same_ep and consec for offset in 1..T) +``` + +This is equivalent to what `HDF5Dataset` does internally, but exposed +explicitly so the dataset object knows all valid start rows upfront. + +The precomputed `_window_starts` array is a plain int64 numpy array stored on +the dataset object — it is pickled safely to DataLoader workers. + +--- + +## What is NOT changed + +- **Training sample format**: each sample is still `{pixels: (T,C,H,W), action: (T,A), ...}` +- **Normalization**: z-score per column, computed on train episodes only (same as `get_column_normalizer` in `le-wm/utils.py`) +- **Image preprocessing**: ImageNet normalization + resize to 224×224 (same as `get_img_preprocessor`) +- **Episode-level train/val split**: 90/10 by default with fixed seed +- **NaN handling**: `nan_to_num(nan=0.0)` on action boundaries (same as original) + +The LanceDB pipeline is a direct structural equivalent of the HDF5 pipeline +with a different I/O backend. diff --git a/examples/leWorldModel/eda_analysis.py b/examples/leWorldModel/eda_analysis.py new file mode 100644 index 0000000..df14511 --- /dev/null +++ b/examples/leWorldModel/eda_analysis.py @@ -0,0 +1,490 @@ +""" +leWorldModel × LanceDB: EDA, analysis, splits, and vector search. + +Run sections independently or top-to-bottom: + python eda_analysis.py --lance-uri ./lewm_lance --table lewm_pusht + +Sections: + 1. Dataset statistics – action/proprio distributions, episode lengths + 2. Episode-level splits – clean train/val/test splits stored as metadata + 3. Temporal coherence checks – verify no off-by-one leakage across episodes + 4. Vector search – ANN search over frame embeddings + 5. Cross-episode retrieval – find episodes with similar goal states + 6. Action entropy analysis – identify high/low diversity episodes + 7. Data quality scan – detect NaN, frozen frames, degenerate episodes + 8. LanceDB vs HDF5 comparison + +Which embedding column to use for vector search (sections 4 & 5): + emb_dinov2 – DINOv2 ViT-S/14 embeddings (best for pre-training EDA) + emb_clip – CLIP ViT-B/32 embeddings (supports text-to-state queries) + emb_lewm – Trained LeWM encoder embeddings (post-training analysis only) + + Add embeddings with: + python create_data.py --embed --embedding-model dinov2 --dataset pusht +""" + +import argparse + +import lancedb +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc + + +# ============================================================================ +# 1. Dataset statistics +# ============================================================================ + +def dataset_statistics(tbl: lancedb.table.Table): + print("\n" + "=" * 60) + print("1. DATASET STATISTICS") + print("=" * 60) + + schema = tbl.schema + total_rows = len(tbl) + + # Read only the two index columns — negligible memory + idx_arrow = tbl.to_arrow(columns=["episode_idx"]) + n_episodes = len(pc.unique(idx_arrow["episode_idx"])) + + print(f"\nTotal timesteps : {total_rows:,}") + print(f"Total episodes : {n_episodes:,}") + print(f"Avg steps/ep : {total_rows / n_episodes:.1f}") + print(f"\nSchema:\n{schema}\n") + + # Per-column stats for fixed-size list columns (skip pixels and embeddings) + list_cols = [ + f.name for f in schema + if (pa.types.is_list(f.type) or pa.types.is_fixed_size_list(f.type)) + and f.name not in ("pixels",) + and not f.name.startswith("emb_") + ] + if not list_cols: + return + + arrow = tbl.to_arrow(columns=list_cols) + for col in list_cols: + data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + valid = ~np.isnan(data).any(axis=1) + data = data[valid] + print(f" {col:<14} dim={data.shape[1]:3d} | " + f"mean={data.mean():+.4f} std={data.std():.4f} " + f"min={data.min():+.4f} max={data.max():+.4f} " + f"NaN rows={(~valid).sum()}") + + # Episode length distribution + ep_arr = idx_arrow["episode_idx"].to_numpy() + _, counts = np.unique(ep_arr, return_counts=True) + print(f"\nEpisode length min={counts.min()} max={counts.max()} " + f"median={int(np.median(counts))} std={counts.std():.1f}") + + +# ============================================================================ +# 2. Episode-level train / val / test splits +# ============================================================================ + +def create_splits( + tbl: lancedb.table.Table, + train: float = 0.8, + val: float = 0.1, + test: float = 0.1, + seed: int = 42, +) -> dict[str, np.ndarray]: + """ + Assign each episode to train/val/test. + + With LanceDB you can use these episode IDs as a filter at training time — + no need to copy or materialise new tables: + train_arrow = tbl.to_arrow( + columns=["pixels", "action"], + filter=f"episode_idx IN {tuple(splits['train'].tolist())}", + ) + """ + print("\n" + "=" * 60) + print("2. EPISODE-LEVEL SPLITS") + print("=" * 60) + + ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + all_eps = np.unique(ep_arr) + rng = np.random.default_rng(seed) + rng.shuffle(all_eps) + + n = len(all_eps) + n_train = int(n * train) + n_val = int(n * val) + + splits = { + "train": all_eps[:n_train], + "val": all_eps[n_train : n_train + n_val], + "test": all_eps[n_train + n_val :], + } + + for name, eps in splits.items(): + ep_mask = np.isin(ep_arr, eps) + print(f" {name:<6}: {len(eps):5,} episodes {ep_mask.sum():8,} timesteps") + + print("\n To use a split in training, filter with:") + print(" tbl.to_arrow(columns=[...], filter='episode_idx IN (0,1,...)')") + + return splits + + +# ============================================================================ +# 3. Temporal coherence check +# ============================================================================ + +def temporal_coherence_check(tbl: lancedb.table.Table): + """ + Verify that episodes are stored contiguously and step indices are + monotonically increasing. Detects truncated or merged episodes. + """ + print("\n" + "=" * 60) + print("3. TEMPORAL COHERENCE CHECK") + print("=" * 60) + + idx = tbl.to_arrow(columns=["episode_idx", "step_idx"]) + ep = idx["episode_idx"].to_numpy() + step = idx["step_idx"].to_numpy() + + ep_changes = np.where(np.diff(ep) != 0)[0] + episode_starts = np.concatenate([[0], ep_changes + 1, [len(ep)]]) + seen_episodes: set[int] = set() + non_contiguous = 0 + for i in range(len(episode_starts) - 1): + eid = ep[episode_starts[i]] + if eid in seen_episodes: + non_contiguous += 1 + seen_episodes.add(eid) + + if non_contiguous == 0: + print(f" [OK] All {len(seen_episodes):,} episodes stored contiguously.") + else: + print(f" [WARN] {non_contiguous} episodes appear in non-contiguous blocks.") + + bad_resets = sum( + 1 for i in range(len(episode_starts) - 1) + if step[episode_starts[i]] != 0 + ) + if bad_resets == 0: + print(" [OK] All episodes start at step_idx = 0.") + else: + print(f" [WARN] {bad_resets} episodes do not start at step_idx = 0.") + + non_mono = sum( + 1 for i in range(len(episode_starts) - 1) + if not np.all(np.diff(step[episode_starts[i]:episode_starts[i + 1]]) == 1) + ) + if non_mono == 0: + print(" [OK] All step indices monotonically increase by 1.") + else: + print(f" [WARN] {non_mono} episodes have non-unit step increments.") + + +# ============================================================================ +# 4. Vector search over frame embeddings +# ============================================================================ + +def vector_search_demo( + tbl: lancedb.table.Table, + emb_col: str = "emb_dinov2", + query_episode: int = 0, + query_step: int = 0, + top_k: int = 10, +): + """ + Find the top_k most similar frames to a query frame (by ANN in embedding space). + + emb_col choices: + emb_dinov2 – for pre-training EDA; semantically meaningful out of the box + emb_clip – for pre-training EDA; also supports text-to-state queries + emb_lewm – for post-training analysis; reflects the learned dynamics + + Requires: python create_data.py --embed --embedding-model {dinov2|clip|lewm} + """ + print("\n" + "=" * 60) + print(f"4. VECTOR SEARCH (column: {emb_col})") + print("=" * 60) + + if emb_col not in tbl.schema.names: + print(f" [SKIP] '{emb_col}' column not found.") + which = emb_col.replace("emb_", "") + print(f" Add it with: python create_data.py --embed --embedding-model {which}") + return + + # Fetch query embedding using a filter — no full-table scan + query_arrow = ( + tbl.search() + .where(f"episode_idx = {query_episode} AND step_idx = {query_step}") + .select([emb_col]) + .limit(1) + .to_arrow() + ) + if len(query_arrow) == 0: + print(f" [SKIP] No row for episode={query_episode}, step={query_step}") + return + + query_emb = np.array(query_arrow[emb_col][0].as_py(), dtype=np.float32) + print(f" Query: episode={query_episode}, step={query_step} (dim={len(query_emb)})") + + results = ( + tbl.search(query_emb.tolist(), vector_column_name=emb_col) + .limit(top_k) + .select(["episode_idx", "step_idx", "_distance"]) + .to_arrow() + ) + + print(f"\n Top-{top_k} nearest neighbors:") + print(f" {'episode_idx':>12} {'step_idx':>10} {'distance':>10}") + for row in results.to_pylist(): + print(f" {row['episode_idx']:>12} {row['step_idx']:>10} {row['_distance']:>10.4f}") + + +# ============================================================================ +# 5. Cross-episode retrieval +# ============================================================================ + +def episode_retrieval_demo( + tbl: lancedb.table.Table, + emb_col: str = "emb_dinov2", + target_episode: int = 0, + top_k: int = 5, +): + """ + Represent each episode as its mean frame embedding, then rank all episodes + by cosine similarity to the target episode. + + Use cases: + - Curriculum learning: order episodes by difficulty (distance from mean) + - Deduplication: detect near-identical demonstrations + - Retrieval-augmented planning: find past episodes like the current state + """ + print("\n" + "=" * 60) + print(f"5. CROSS-EPISODE RETRIEVAL (column: {emb_col})") + print("=" * 60) + + if emb_col not in tbl.schema.names: + print(f" [SKIP] '{emb_col}' column not found.") + return + + arrow = tbl.to_arrow(columns=["episode_idx", emb_col]) + ep_arr = arrow["episode_idx"].to_numpy() + emb_arr = np.stack([row.as_py() for row in arrow[emb_col]], axis=0).astype(np.float32) + + unique_eps = np.unique(ep_arr) + ep_means = {ep: emb_arr[ep_arr == ep].mean(axis=0) for ep in unique_eps} + + query_mean = ep_means[target_episode] + all_eps = np.array(list(ep_means.keys())) + all_embs = np.stack(list(ep_means.values()), axis=0) + sims = (all_embs @ query_mean) / ( + np.linalg.norm(all_embs, axis=1) * np.linalg.norm(query_mean) + 1e-8 + ) + # Sort descending, skip self + order = np.argsort(-sims) + order = order[all_eps[order] != target_episode] + + print(f"\n Query episode: {target_episode}") + print(f" {'episode':>10} {'cosine_sim':>12}") + for idx in order[:top_k]: + print(f" {all_eps[idx]:>10} {sims[idx]:>12.4f}") + + +# ============================================================================ +# 6. Action entropy analysis +# ============================================================================ + +def action_entropy_analysis(tbl: lancedb.table.Table, top_k: int = 5): + """ + Compute per-episode action entropy as a proxy for behavioural diversity. + + High entropy → varied actions → good for diverse training + Low entropy → repetitive trajectories → candidate for deduplication + """ + print("\n" + "=" * 60) + print("6. ACTION ENTROPY ANALYSIS") + print("=" * 60) + + if "action" not in tbl.schema.names: + print(" [SKIP] No action column.") + return + + arrow = tbl.to_arrow(columns=["episode_idx", "action"]) + ep_arr = arrow["episode_idx"].to_numpy() + act_arr = np.stack([row.as_py() for row in arrow["action"]], axis=0).astype(np.float32) + + unique_eps = np.unique(ep_arr) + entropies = { + ep: float(np.log(act_arr[ep_arr == ep].std(axis=0) + 1e-8).mean()) + for ep in unique_eps + } + + sorted_eps = sorted(entropies.items(), key=lambda x: x[1]) + + print(f"\n Least diverse (lowest action entropy):") + for ep, ent in sorted_eps[:top_k]: + print(f" episode {ep:5d}: entropy = {ent:.4f}") + + print(f"\n Most diverse (highest action entropy):") + for ep, ent in sorted_eps[-top_k:][::-1]: + print(f" episode {ep:5d}: entropy = {ent:.4f}") + + return entropies + + +# ============================================================================ +# 7. Data quality scan +# ============================================================================ + +def data_quality_scan(tbl: lancedb.table.Table): + """ + Scan for: NaN values, degenerate short episodes, and pixel column presence. + """ + print("\n" + "=" * 60) + print("7. DATA QUALITY SCAN") + print("=" * 60) + + schema = tbl.schema + total = len(tbl) + + # NaN scan — only for non-pixel, non-embedding vector columns + list_cols = [ + f.name for f in schema + if (pa.types.is_list(f.type) or pa.types.is_fixed_size_list(f.type)) + and f.name not in ("pixels",) + and not f.name.startswith("emb_") + ] + for col in list_cols: + arrow = tbl.to_arrow(columns=[col]) + data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + n_nan = int(np.isnan(data).any(axis=1).sum()) + pct = 100 * n_nan / total + flag = "[WARN]" if pct > 5 else "[OK] " + print(f" {flag} {col:<14} NaN rows: {n_nan:,} ({pct:.1f}%)") + + # Degenerate episode check + ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + _, counts = np.unique(ep_arr, return_counts=True) + short_eps = int((counts < 4).sum()) + if short_eps == 0: + print(f" [OK] All {len(counts):,} episodes have ≥ 4 steps (suitable for T=4 windows).") + else: + print(f" [WARN] {short_eps} episodes have < 4 steps — they produce no valid training windows.") + + # Embedding columns present + emb_cols = [f.name for f in schema if f.name.startswith("emb_")] + if emb_cols: + print(f"\n Embedding columns present: {emb_cols}") + print(" Vector search is available on these columns.") + else: + print("\n No embedding columns. Run: python create_data.py --embed --embedding-model dinov2") + + +# ============================================================================ +# 8. LanceDB vs HDF5 comparison +# ============================================================================ + +def print_lancedb_vs_hdf5(): + comparison = """ +╔══════════════════════════════╦════════════════════════════╦════════════════════════════╗ +║ Feature ║ LanceDB ║ HDF5 ║ +╠══════════════════════════════╬════════════════════════════╬════════════════════════════╣ +║ Random row access ║ O(1) via Permutation ║ O(1) but single-threaded ║ +║ Columnar reads ║ Native Arrow columns ║ Compound datasets only ║ +║ Multi-process reads ║ Yes (per-worker conn.) ║ No (POSIX file lock) ║ +║ Vector / ANN search ║ Built-in IVF-PQ index ║ Not supported ║ +║ SQL-like filter queries ║ Yes (DuckDB dialect) ║ No ║ +║ Cloud-native (S3/GCS) ║ Native, parallel ║ Download first ║ +║ Schema evolution ║ Add columns in-place ║ Limited (no column drop) ║ +║ Versioning / time-travel ║ Yes (Lance versioning) ║ No ║ +║ Embedding storage ║ Native fixed_size_list ║ Separate dataset ║ +║ Episode-level filters ║ episode_idx = 42 ║ Loop + mask in Python ║ +║ Train/val split ║ Filter query, zero copy ║ Copy or index arrays ║ +║ Arrow zero-copy tensors ║ Yes (with_format="arrow") ║ No (numpy copy always) ║ +║ Concurrent writers ║ Yes (append-safe) ║ No ║ +║ Compressed pixel storage ║ JPEG binary column ║ Raw uint8 (3-13× larger) ║ +╚══════════════════════════════╩════════════════════════════╩════════════════════════════╝ +""" + print("\n" + "=" * 60) + print("8. LANCEDB vs HDF5 — Feature Comparison") + print("=" * 60) + print(comparison) + + print("""Key advantages for leWorldModel: + +1. MULTI-PROCESS DATALOADERS + HDF5 uses a POSIX file lock. Eight DataLoader workers trying to read the + same .hdf5 file simultaneously either serialize or crash. The standard + workaround (copy the file into /dev/shm) wastes RAM and requires manual + setup. LanceDB opens an independent connection per worker with no locking. + +2. PRE-TRAINING EDA WITH FOUNDATION MODEL EMBEDDINGS + Run create_data.py --embed --embedding-model dinov2 before training and + you immediately get ANN search, episode clustering, and similarity + retrieval using DINOv2 semantics — before your LeWM model sees a single + gradient. Not possible with HDF5 without a separate vector store. + +3. POST-TRAINING ANALYSIS WITH LEWM EMBEDDINGS + After training, add a second embedding column (--embedding-model lewm). + You can now compare what DINOv2 vs LeWM consider "similar" — a direct + window into what the world model has learned to focus on. + +4. EPISODE FILTERING WITHOUT ARRAY MANIPULATION + tbl.to_arrow(filter="episode_idx IN (...)") returns only the matching + rows, columnar-compressed, as Arrow. With HDF5 you load the full array + and mask in Python. + +5. ZERO-COPY ARROW FORMAT IN DATALOADERS + Permutation.with_format("arrow") returns a pa.RecordBatch that converts + to tensors without memory copy. HDF5 always goes through numpy. + +6. VERSIONING + Every table.add() creates a new Lance version. You can audit, roll back, + or diff data additions — critical for reproducible experiment tracking. +""") + + +# ============================================================================ +# CLI +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser(description="leWorldModel LanceDB EDA and analysis") + parser.add_argument("--lance-uri", default="./lewm_lance") + parser.add_argument("--table", default="lewm_pusht") + parser.add_argument( + "--emb-col", + default="emb_dinov2", + help="Embedding column to use for vector search sections (default: emb_dinov2)", + ) + parser.add_argument( + "--section", + default="all", + choices=["all", "stats", "splits", "coherence", + "vector_search", "retrieval", "entropy", "quality", "comparison"], + ) + args = parser.parse_args() + + db = lancedb.connect(args.lance_uri) + tbl = db.open_table(args.table) + run_all = args.section == "all" + + if run_all or args.section == "stats": + dataset_statistics(tbl) + if run_all or args.section == "splits": + create_splits(tbl) + if run_all or args.section == "coherence": + temporal_coherence_check(tbl) + if run_all or args.section == "vector_search": + vector_search_demo(tbl, emb_col=args.emb_col) + if run_all or args.section == "retrieval": + episode_retrieval_demo(tbl, emb_col=args.emb_col) + if run_all or args.section == "entropy": + action_entropy_analysis(tbl) + if run_all or args.section == "quality": + data_quality_scan(tbl) + if run_all or args.section == "comparison": + print_lancedb_vs_hdf5() + + +if __name__ == "__main__": + main() diff --git a/examples/leWorldModel/lewm_lance/__init__.py b/examples/leWorldModel/lewm_lance/__init__.py new file mode 100644 index 0000000..f00b070 --- /dev/null +++ b/examples/leWorldModel/lewm_lance/__init__.py @@ -0,0 +1,9 @@ +from .dataset import LeWMLanceDataset, compute_normalizers +from .dataloaders import make_lewm_lance_loader, make_train_val_loaders + +__all__ = [ + "LeWMLanceDataset", + "compute_normalizers", + "make_lewm_lance_loader", + "make_train_val_loaders", +] diff --git a/examples/leWorldModel/lewm_lance/dataloaders.py b/examples/leWorldModel/lewm_lance/dataloaders.py new file mode 100644 index 0000000..26c7a38 --- /dev/null +++ b/examples/leWorldModel/lewm_lance/dataloaders.py @@ -0,0 +1,147 @@ +""" +DataLoader factories for leWorldModel LanceDB-backed training. + +Two public functions: + make_lewm_lance_loader() – single loader, caller provides a pre-built dataset + make_train_val_loaders() – episode-level train/val split, returns two loaders + +Episode-level split (not random-row split) avoids data leakage: + all timesteps of a given episode go entirely to train or entirely to val. +""" + +import lancedb +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .dataset import LeWMLanceDataset, compute_normalizers + + +# --------------------------------------------------------------------------- +# Collate: list[{key: (T,...) tensor}] → {key: (B,T,...) tensor} +# --------------------------------------------------------------------------- + +def _lewm_collate(samples: list[dict]) -> dict: + keys = samples[0].keys() + return {k: torch.stack([s[k] for s in samples], dim=0) for k in keys} + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def make_lewm_lance_loader( + uri: str, + table_name: str, + columns: list[str], + batch_size: int, + num_steps: int = 4, + img_size: int = 224, + num_workers: int = 6, + prefetch_factor: int = 3, + normalizers: dict | None = None, + **connect_kwargs, +) -> DataLoader: + """ + Build a DataLoader over a LanceDB leWorldModel table. + + Args: + uri: LanceDB URI. + table_name: Table name (as created by create_data.py). + columns: Columns to load, e.g. ["pixels", "action", "proprio"]. + batch_size: Training batch size B. + num_steps: Window length T = history_size + num_preds. + img_size: Resize target (square) applied before ImageNet normalization. + num_workers: DataLoader worker processes. + prefetch_factor: Batches queued per worker. + normalizers: {col: (mean, std)} from compute_normalizers(); if None, + no normalization is applied to non-pixel columns. + **connect_kwargs: Forwarded to lancedb.connect() (api_key, host_override, …). + """ + dataset = LeWMLanceDataset( + uri=uri, + table_name=table_name, + columns=columns, + num_steps=num_steps, + img_size=img_size, + normalizers=normalizers, + **connect_kwargs, + ) + return _build_loader(dataset, batch_size, num_workers, prefetch_factor) + + +def make_train_val_loaders( + uri: str, + table_name: str, + columns: list[str], + batch_size: int, + num_steps: int = 4, + img_size: int = 224, + num_workers: int = 6, + prefetch_factor: int = 3, + val_fraction: float = 0.1, + seed: int = 42, + **connect_kwargs, +) -> tuple[DataLoader, DataLoader]: + """ + Episode-level train/val split. + + val_fraction of episodes (randomly sampled, seeded) are held out for + validation. Normalizers are computed on training episodes only. + + Returns: + (train_loader, val_loader) + """ + db = lancedb.connect(uri, **connect_kwargs) + tbl = db.open_table(table_name) + + # Only reads one int32 column — negligible memory even at millions of rows + ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + all_episodes = np.unique(ep_arr) + + rng = np.random.default_rng(seed) + rng.shuffle(all_episodes) + n_val = max(1, int(len(all_episodes) * val_fraction)) + val_episodes = set(all_episodes[:n_val].tolist()) + train_episodes = set(all_episodes[n_val:].tolist()) + + print(f" Split: {len(train_episodes)} train episodes, {len(val_episodes)} val episodes") + + # Compute normalizers on training data only to avoid leakage + normalizers = compute_normalizers(uri, table_name, columns, **connect_kwargs) + + # Build full datasets then restrict _window_starts by episode membership + train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, img_size, normalizers, **connect_kwargs) + val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, img_size, normalizers, **connect_kwargs) + + train_ep_mask = np.isin(train_ds._ep[train_ds._window_starts], list(train_episodes)) + val_ep_mask = np.isin(val_ds._ep[val_ds._window_starts], list(val_episodes)) + + train_ds._window_starts = train_ds._window_starts[train_ep_mask] + val_ds._window_starts = val_ds._window_starts[val_ep_mask] + + print(f" Windows: {len(train_ds):,} train, {len(val_ds):,} val") + + return ( + _build_loader(train_ds, batch_size, num_workers, prefetch_factor), + _build_loader(val_ds, batch_size, num_workers, prefetch_factor), + ) + + +# --------------------------------------------------------------------------- +# Internal +# --------------------------------------------------------------------------- + +def _build_loader(dataset: LeWMLanceDataset, batch_size: int, num_workers: int, prefetch_factor: int) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + collate_fn=_lewm_collate, + persistent_workers=(num_workers > 0), + prefetch_factor=prefetch_factor if num_workers > 0 else None, + multiprocessing_context="spawn" if num_workers > 0 else None, + ) diff --git a/examples/leWorldModel/lewm_lance/dataset.py b/examples/leWorldModel/lewm_lance/dataset.py new file mode 100644 index 0000000..4b38879 --- /dev/null +++ b/examples/leWorldModel/lewm_lance/dataset.py @@ -0,0 +1,231 @@ +""" +LanceDB-backed PyTorch Dataset for leWorldModel temporal sequences. + +leWorldModel trains on windows of T consecutive frames from the same episode: + T = history_size (3) + num_preds (1) = 4 by default + +Each dataset item is a dict of tensors: + "pixels" : (T, C, H, W) float32 ImageNet-normalized + "action" : (T, A) float32 z-score normalized, NaN→0 + "proprio" : (T, P) float32 z-score normalized [if present] + ... + +Design: + - One LanceDB row = one timestep. + - Window index is precomputed at __init__ so __getitems__ only does I/O. + - Permutation object (Rust state) is zeroed before pickling; each worker + lazily reopens its own connection inside _ensure_open(). + - __getitems__ batches the full (B*T) row fetch in a single Permutation call, + then splits into per-sample dicts — same pattern as ViT dataloaders.py. +""" + +import io + +import lancedb +import numpy as np +import pyarrow as pa +import torch +from lancedb.permutation import Permutation +from PIL import Image +from torchvision import transforms + + +_IMAGENET_MEAN = [0.485, 0.456, 0.406] +_IMAGENET_STD = [0.229, 0.224, 0.225] + + +def _build_img_transform(img_size: int) -> transforms.Compose: + return transforms.Compose([ + transforms.Resize((img_size, img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD), + ]) + + +def _jpeg_to_tensor(jpeg_bytes: bytes, transform: transforms.Compose) -> torch.Tensor: + img = Image.open(io.BytesIO(jpeg_bytes)).convert("RGB") + return transform(img) + + +def compute_normalizers( + uri: str, + table_name: str, + columns: list[str], + **connect_kwargs, +) -> dict[str, tuple[np.ndarray, np.ndarray]]: + """ + Compute per-column (mean, std) arrays for z-score normalization. + + Reads only the requested non-pixel columns using LanceDB's column projection + so no pixel data is loaded. For large datasets this streams in batches via + the Arrow scanner rather than materializing the full table at once. + + Returns: + {col: (mean_array, std_array)} — each array has shape (D,). + """ + db = lancedb.connect(uri, **connect_kwargs) + tbl = db.open_table(table_name) + non_pixel = [c for c in columns if c != "pixels"] + if not non_pixel: + return {} + + # Use column-projected Arrow read — loads only the requested columns. + # episode_idx + step_idx columns are tiny; the vector columns are float32 + # lists already in Arrow format, so this is as efficient as possible. + arrow = tbl.to_arrow(columns=non_pixel) + normalizers = {} + for col in non_pixel: + data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + valid = ~np.isnan(data).any(axis=1) + data = data[valid] + normalizers[col] = (data.mean(axis=0), data.std(axis=0)) + return normalizers + + +class LeWMLanceDataset(torch.utils.data.Dataset): + """ + Temporal-window dataset backed by a LanceDB table. + + Args: + uri: LanceDB URI (local path or s3://…). + table_name: Name of the table created by create_data.py. + columns: List of column names to return, e.g. ["pixels","action","proprio"]. + num_steps: Window length T (= history_size + num_preds). + img_size: Target image size after resize (square). + normalizers: Output of compute_normalizers(); used to z-score non-pixel columns. + **connect_kwargs: Passed to lancedb.connect() (api_key, host_override, region, …). + """ + + def __init__( + self, + uri: str, + table_name: str, + columns: list[str], + num_steps: int = 4, + img_size: int = 224, + normalizers: dict | None = None, + **connect_kwargs, + ): + self.uri = uri + self.table_name = table_name + self.columns = columns + self.num_steps = num_steps + self.img_size = img_size + self.normalizers = normalizers or {} + self.connect_kwargs = connect_kwargs + + # Rust Permutation — zeroed before pickling, rebuilt per-worker + self._perm: Permutation | None = None + self._transform: transforms.Compose | None = None + + # ------------------------------------------------------------------ + # Eagerly load the episode/step index to precompute valid windows. + # These are only two int32 columns — ~8 bytes/row regardless of + # dataset size, so loading them fully is fine. + # ------------------------------------------------------------------ + db = lancedb.connect(uri, **connect_kwargs) + tbl = db.open_table(table_name) + idx_arrow = tbl.to_arrow(columns=["episode_idx", "step_idx"]) + self._ep = idx_arrow["episode_idx"].to_numpy().astype(np.int32) + self._step = idx_arrow["step_idx"].to_numpy().astype(np.int32) + self._n_rows = len(self._ep) + + # Precompute valid window start rows. + # A window starting at row i is valid iff rows i..i+T-1 are all in + # the same episode and have consecutive step indices. + T = num_steps + N = self._n_rows - T + 1 + valid = np.ones(N, dtype=bool) + for offset in range(1, T): + same_ep = self._ep[offset : N + offset] == self._ep[:N] + consec = self._step[offset : N + offset] == self._step[:N] + offset + valid &= same_ep & consec + + # _window_starts[i] = absolute row index for the i-th valid window + self._window_starts = np.where(valid)[0].astype(np.int64) + + # ---------------------------------------------------------------------- # + # PyTorch Dataset protocol + # ---------------------------------------------------------------------- # + + def __len__(self) -> int: + return len(self._window_starts) + + def __getstate__(self) -> dict: + """Zero out Rust state before the object is pickled for a worker process.""" + state = self.__dict__.copy() + state["_perm"] = None + state["_transform"] = None + return state + + def _ensure_open(self): + """Lazily open DB connection + Permutation once per worker process.""" + if self._perm is None: + db = lancedb.connect(self.uri, **self.connect_kwargs) + tbl = db.open_table(self.table_name) + fetch_cols = ["pixels"] + [c for c in self.columns if c != "pixels"] + self._perm = ( + Permutation.identity(tbl) + .select_columns(fetch_cols) + .with_format("arrow") + ) + self._transform = _build_img_transform(self.img_size) + + # ---------------------------------------------------------------------- # + # Internal: convert a RecordBatch of T rows into a sample dict + # ---------------------------------------------------------------------- # + + def _rows_to_sample(self, batch: pa.RecordBatch) -> dict[str, torch.Tensor]: + T = self.num_steps + assert len(batch) == T + + # Decode JPEG pixels → (T, C, H, W) + jpeg_list = batch["pixels"].to_pylist() + frames = torch.stack([_jpeg_to_tensor(b, self._transform) for b in jpeg_list]) + sample: dict[str, torch.Tensor] = {"pixels": frames} + + for col in self.columns: + if col == "pixels": + continue + data = np.array([batch[col][t].as_py() for t in range(T)], dtype=np.float32) + if col in self.normalizers: + mean, std = self.normalizers[col] + data = (data - mean) / (std + 1e-8) + data = np.nan_to_num(data, nan=0.0) + sample[col] = torch.from_numpy(data) + + return sample + + # ---------------------------------------------------------------------- # + # Single-item access (used when num_workers=0) + # ---------------------------------------------------------------------- # + + def __getitem__(self, window_idx: int) -> dict[str, torch.Tensor]: + self._ensure_open() + start = int(self._window_starts[window_idx]) + rows = list(range(start, start + self.num_steps)) + batch = self._perm.__getitems__(rows) + return self._rows_to_sample(batch) + + # ---------------------------------------------------------------------- # + # Batch access — called by DataLoader with num_workers > 0. + # Fetches all B*T rows in ONE Permutation call instead of B calls. + # ---------------------------------------------------------------------- # + + def __getitems__(self, window_indices: list[int]) -> list[dict[str, torch.Tensor]]: + self._ensure_open() + T = self.num_steps + starts = self._window_starts[window_indices] # (B,) + + # Build flat list: [w0_t0, w0_t1, …, w1_t0, w1_t1, …] + all_rows: list[int] = [] + for s in starts: + all_rows.extend(range(int(s), int(s) + T)) + + big_batch: pa.RecordBatch = self._perm.__getitems__(all_rows) # (B*T, cols) + + samples = [] + for b in range(len(window_indices)): + row_slice = big_batch.slice(b * T, T) + samples.append(self._rows_to_sample(row_slice)) + return samples diff --git a/examples/leWorldModel/requirements.txt b/examples/leWorldModel/requirements.txt new file mode 100644 index 0000000..59257f2 --- /dev/null +++ b/examples/leWorldModel/requirements.txt @@ -0,0 +1,28 @@ +# leWorldModel × LanceDB requirements + +# Core +lancedb>=0.20.0 +pyarrow>=16.0.0 +torch>=2.2.0 +torchvision>=0.17.0 + +# Training +pytorch-lightning>=2.2.0 +timm>=1.0.0 +pyyaml>=6.0 +stable-worldmodel + +# Data +h5py>=3.10.0 +hdf5plugin>=4.3.0 +Pillow>=10.0.0 +numpy>=1.26.0 +tqdm>=4.66.0 +s3fs>=2024.2.0 + +# Embedding models +# openai-clip — for --embedding-model clip +# geneva — LanceDB Enterprise, for embedding backfill + +# Logging +wandb>=0.17.0 diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py new file mode 100644 index 0000000..b76b50c --- /dev/null +++ b/examples/leWorldModel/train.py @@ -0,0 +1,365 @@ +""" +leWorldModel trainer with LanceDB data backend. + +Drop-in replacement for le-wm/train.py that swaps the HDF5 data pipeline for +a LanceDB-backed DataLoader while keeping the model, loss, and Lightning +training loop identical to the original. + +Usage: + # Local LanceDB store (defaults come from config) + python train.py --config config/lewm_pusht.yaml + + # S3-backed store, credentials via CLI + python train.py --config config/lewm_pusht.yaml \\ + --lance-uri s3://my-bucket/lewm \\ + --aws-region us-east-1 \\ + --aws-access-key-id AKIA... \\ + --aws-secret-access-key ... + + # S3-backed store, credentials via environment (AWS_ACCESS_KEY_ID etc.) + python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm + + # Override table and columns without editing the config + python train.py --config config/lewm_pusht.yaml \\ + --table-name lewm_reacher \\ + --columns pixels action observation +""" + +import argparse +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) + +import pytorch_lightning as pl +import timm +import torch +import yaml +from pytorch_lightning.loggers import WandbLogger + +from jepa import JEPA +from lewm_lance import make_train_val_loaders +from module import ARPredictor, Embedder, MLP, SIGReg +from stable_worldmodel.optim import LinearWarmupCosineAnnealingLR +from utils import ModelObjectCallBack + + +# --------------------------------------------------------------------------- +# Lightning module +# --------------------------------------------------------------------------- + +class LeWMLightning(pl.LightningModule): + """ + PyTorch Lightning wrapper around the JEPA world model. + + Accepts the batch dict produced by make_train_val_loaders: + "pixels" : (B, T, C, H, W) float32 + "action" : (B, T, A) float32 + ...additional columns... + """ + + def __init__(self, model: JEPA, sigreg: SIGReg, cfg: dict): + super().__init__() + self.model = model + self.sigreg = sigreg + self.cfg = cfg + self.save_hyperparameters(ignore=["model", "sigreg"]) + + def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: + loss_pred = self.model.criterion(batch)["loss"] + loss_reg = self.sigreg(self.model.last_embeddings) + loss = loss_pred + self.cfg["sigreg_weight"] * loss_reg + self.log(f"{stage}/loss_pred", loss_pred, on_step=(stage == "train"), on_epoch=True, prog_bar=True) + self.log(f"{stage}/loss_reg", loss_reg, on_step=(stage == "train"), on_epoch=True) + self.log(f"{stage}/loss", loss, on_step=(stage == "train"), on_epoch=True, prog_bar=True) + return loss + + def training_step(self, batch: dict, _) -> torch.Tensor: + return self._shared_step(batch, "train") + + def validation_step(self, batch: dict, _) -> None: + self._shared_step(batch, "val") + + def configure_optimizers(self): + opt = torch.optim.AdamW( + self.model.parameters(), + lr=self.cfg["lr"], + weight_decay=self.cfg["weight_decay"], + ) + sched = LinearWarmupCosineAnnealingLR( + opt, + warmup_epochs=self.cfg["warmup_epochs"], + max_epochs=self.cfg["max_epochs"], + ) + return [opt], [{"scheduler": sched, "interval": "epoch"}] + + +# --------------------------------------------------------------------------- +# Model construction +# --------------------------------------------------------------------------- + +def build_model(cfg: dict) -> tuple[JEPA, SIGReg]: + m = cfg["model"] + + encoder = timm.create_model( + m["encoder_name"], + pretrained=False, + img_size=m["image_size"], + num_classes=0, + ) + + predictor = ARPredictor( + embed_dim=m["embed_dim"], + depth=m["predictor_depth"], + num_heads=m["predictor_heads"], + mlp_dim=m["predictor_mlp_dim"], + max_seq_len=m["history_size"] + m["num_preds"], + ) + + action_encoder = Embedder( + in_dim=m["action_dim"], + out_dim=m["embed_dim"], + ) + + encoder_dim = encoder.embed_dim + projector = MLP(encoder_dim, m["proj_hidden"], m["embed_dim"]) + pred_proj = MLP(m["embed_dim"], m["proj_hidden"], m["embed_dim"]) + + model = JEPA( + encoder=encoder, + predictor=predictor, + action_encoder=action_encoder, + projector=projector, + pred_proj=pred_proj, + history_size=m["history_size"], + num_preds=m["num_preds"], + ) + + sigreg = SIGReg( + embed_dim=m["embed_dim"], + knots=cfg["loss"]["sigreg_knots"], + num_proj=cfg["loss"]["sigreg_num_proj"], + ) + + return model, sigreg + + +# --------------------------------------------------------------------------- +# S3 storage options +# --------------------------------------------------------------------------- + +def build_storage_options(args: argparse.Namespace) -> dict: + """ + Build the storage_options dict for lancedb.connect() from CLI args, + falling back to standard AWS environment variables. + + LanceDB passes storage_options directly to the Rust object_store library, + which accepts these keys for S3: + aws_access_key_id, aws_secret_access_key, aws_session_token, + region, endpoint_url, aws_virtual_hosted_style_request + + Environment variable fallbacks follow the standard AWS SDK convention: + AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, + AWS_DEFAULT_REGION, AWS_ENDPOINT_URL + + Returns an empty dict for local URIs (no storage_options needed). + """ + uri = args.lance_uri + if not (uri.startswith("s3://") or uri.startswith("gs://") or uri.startswith("az://")): + return {} + + opts: dict[str, str] = {} + + access_key = args.aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") + secret_key = args.aws_secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") + session_token = args.aws_session_token or os.environ.get("AWS_SESSION_TOKEN") + region = args.aws_region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") + endpoint = args.s3_endpoint or os.environ.get("AWS_ENDPOINT_URL") + + if access_key: + opts["aws_access_key_id"] = access_key + if secret_key: + opts["aws_secret_access_key"] = secret_key + if session_token: + opts["aws_session_token"] = session_token + if region: + opts["region"] = region + if endpoint: + opts["endpoint_url"] = endpoint + opts["aws_virtual_hosted_style_request"] = "false" + + return opts + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Train leWorldModel with LanceDB data backend", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--config", + default="config/lewm_pusht.yaml", + help="Path to YAML training config", + ) + parser.add_argument( + "--lance-uri", + default=None, + help="LanceDB URI. Defaults to data.lance_uri in config. " + "Use s3://bucket/prefix for cloud storage.", + ) + parser.add_argument( + "--table-name", + default=None, + help="LanceDB table name. Defaults to data.table_name in config.", + ) + parser.add_argument( + "--columns", + nargs="+", + default=None, + help="Columns to load. Defaults to data.columns in config.", + ) + parser.add_argument( + "--run-name", + default=None, + help="WandB run name. Defaults to -steps.", + ) + parser.add_argument( + "--no-wandb", + action="store_true", + help="Disable WandB logging.", + ) + + s3 = parser.add_argument_group( + "S3 storage options", + "Credentials for S3-backed LanceDB tables. All args fall back to " + "standard AWS environment variables (AWS_ACCESS_KEY_ID, etc.).", + ) + s3.add_argument("--aws-access-key-id", default=None, metavar="KEY") + s3.add_argument("--aws-secret-access-key", default=None, metavar="SECRET") + s3.add_argument("--aws-session-token", default=None, metavar="TOKEN", + help="Temporary session token (STS / IAM role assumed credentials).") + s3.add_argument("--aws-region", default=None, metavar="REGION", + help="AWS region, e.g. us-east-1.") + s3.add_argument("--s3-endpoint", default=None, metavar="URL", + help="Custom S3-compatible endpoint (MinIO, R2, etc.).") + + return parser + + +def main(): + parser = _build_parser() + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + data_cfg = cfg["data"] + model_cfg = cfg["model"] + train_cfg = cfg["training"] + loss_cfg = cfg.get("loss", {}) + + # Resolve: CLI arg > config value > hardcoded fallback + lance_uri = args.lance_uri or data_cfg.get("lance_uri", "./lewm_lance") + table_name = args.table_name or data_cfg.get("table_name") + columns = args.columns or data_cfg["columns"] + num_steps = model_cfg["history_size"] + model_cfg["num_preds"] + + if table_name is None: + parser.error("table_name is required: set data.table_name in config or pass --table-name") + + storage_options = build_storage_options(args) + if storage_options: + print(f" S3 storage: region={storage_options.get('region', 'env')}" + + (f" endpoint={storage_options['endpoint_url']}" if "endpoint_url" in storage_options else "")) + + # ------------------------------------------------------------------ # + # Data + # ------------------------------------------------------------------ # + # storage_options is only passed for cloud URIs; empty dict for local paths + connect_kwargs = {"storage_options": storage_options} if storage_options else {} + + print(f"Building DataLoaders ({lance_uri} / {table_name})...") + train_loader, val_loader = make_train_val_loaders( + uri=lance_uri, + table_name=table_name, + columns=columns, + batch_size=train_cfg["batch_size"], + num_steps=num_steps, + img_size=model_cfg["image_size"], + num_workers=data_cfg["num_workers"], + prefetch_factor=data_cfg["prefetch_factor"], + val_fraction=data_cfg["val_fraction"], + seed=data_cfg["seed"], + **connect_kwargs, + ) + print(f" Train batches: {len(train_loader):,} | Val batches: {len(val_loader):,}") + + # ------------------------------------------------------------------ # + # Model + # ------------------------------------------------------------------ # + print("Building model...") + model_cfg["action_dim"] = _infer_action_dim(train_loader) + model, sigreg = build_model(cfg) + n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" Trainable parameters: {n_params / 1e6:.1f}M") + + lightning_model = LeWMLightning( + model=model, + sigreg=sigreg, + cfg={**train_cfg, **loss_cfg}, + ) + + # ------------------------------------------------------------------ # + # Logging & callbacks + # ------------------------------------------------------------------ # + run_name = args.run_name or f"{table_name}-{num_steps}T" + + logger = None + if not args.no_wandb: + logger = WandbLogger( + project=cfg.get("wandb_project", "lewm-lancedb"), + name=run_name, + config={**cfg, "lance_uri": lance_uri, "table": table_name}, + ) + + ckpt_dir = train_cfg["checkpoint_dir"] + os.makedirs(ckpt_dir, exist_ok=True) + + callbacks = [ + ModelObjectCallBack( + dirpath=ckpt_dir, + filename=f"{table_name}_lewm", + epoch_interval=train_cfg["save_every_n_epochs"], + ) + ] + + # ------------------------------------------------------------------ # + # Trainer + # ------------------------------------------------------------------ # + trainer = pl.Trainer( + max_epochs=train_cfg["max_epochs"], + precision=train_cfg["precision"], + gradient_clip_val=train_cfg["gradient_clip"], + logger=logger, + callbacks=callbacks, + log_every_n_steps=train_cfg["log_every_n_steps"], + enable_progress_bar=True, + ) + + print("Starting training...") + trainer.fit(lightning_model, train_loader, val_loader) + print("Training complete.") + + +def _infer_action_dim(loader: torch.utils.data.DataLoader) -> int: + batch = next(iter(loader)) + return batch["action"].shape[-1] + + +if __name__ == "__main__": + main() From 4b2eb783d51ccb4c749fc26ee254bf702db55e65 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 15:54:43 +0530 Subject: [PATCH 02/29] update --- examples/leWorldModel/README.md | 37 ++++++++++++--- examples/leWorldModel/create_data.py | 68 +++++++++++++++++----------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index 855f04c..68ab389 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -55,13 +55,38 @@ For the DataLoader, `num_workers=6` works well with a local LanceDB store. With ## Reproducing the paper -### Step 1 — Convert datasets +### Step 1 — Collect / download datasets -Download the datasets via `stable-worldmodel` and convert each to a LanceDB table. -All four datasets can share a single LanceDB store (each gets its own table). +Three datasets are collected locally using the stable-worldmodel expert scripts. +The cube dataset (`ogbench/cube_single_expert`) is downloaded automatically from +HuggingFace during conversion — no manual step needed for that one. ```bash +# Clone stable-worldmodel for the data collection scripts +git clone https://github.com/galilai-group/stable-worldmodel /tmp/stable-worldmodel +cd /tmp/stable-worldmodel && pip install -e . + +# Collect PushT expert demonstrations (~1000 episodes) +python scripts/data/collect_weak_pusht.py + +# Collect DMControl Reacher +python scripts/data/collect_dmc.py + +# Collect TwoRoom +python scripts/data/collect_tworooms.py + +# cube is auto-downloaded from HuggingFace in the next step +``` + +HDF5 files are saved to `$STABLEWM_HOME` (default: `~/.stable_worldmodel/`). + +### Step 2 — Convert datasets to LanceDB + +```bash +cd /path/to/examples/leWorldModel + # All four datasets into one local store +# (cube is fetched from HuggingFace automatically if not cached) python create_data.py --dataset all --lance-uri ./lewm_lance # Or one at a time @@ -76,7 +101,7 @@ python create_data.py --dataset all --lance-uri s3://my-bucket/lewm This creates four tables: `lewm_reacher`, `lewm_cube`, `lewm_pusht`, `lewm_tworoom`. -### Step 2 — EDA and data quality check +### Step 3 — EDA and data quality check Run this **before** training to catch any data issues and understand each dataset. Uses DINOv2 embeddings (no training needed — frozen foundation model). @@ -98,7 +123,7 @@ python eda_analysis.py --table lewm_pusht --section entropy # find divers python eda_analysis.py --table lewm_pusht --section stats ``` -### Step 3 — Train on each dataset +### Step 4 — Train on each dataset Each dataset is trained independently. Create a config per dataset by copying `config/lewm_pusht.yaml` and updating `data.table_name` and `data.columns`. @@ -133,7 +158,7 @@ python train.py --config config/lewm_pusht.yaml \ python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm ``` -### Step 4 — Post-training analysis +### Step 5 — Post-training analysis After training, add LeWM encoder embeddings to the table for world-model-specific analysis. These reflect what the trained model considers semantically similar — different from DINOv2. diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index 11732ad..d8b4424 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -4,13 +4,27 @@ Each LanceDB row = one timestep (same granularity as the source HDF5). See dataset.md for full format documentation. -Supported datasets: - reacher → ~/.stable-wm/reacher.hdf5 - cube → ~/.stable-wm/cube_single_expert.hdf5 - pusht → ~/.stable-wm/pusht_expert_train.hdf5 - tworoom → ~/.stable-wm/tworoom.hdf5 +COLLECTING THE DATASETS +----------------------- +Three datasets (reacher, pusht, tworoom) must be collected locally using the +stable-worldmodel expert scripts before converting. The cube dataset is +downloaded automatically from HuggingFace (ogbench/cube_single_expert). -Usage: + # Collect reacher (~30 min on a single GPU with mujoco) + python scripts/data/collect_dmc.py + + # Collect pusht + python scripts/data/collect_pusht_fov.py # or collect_weak_pusht.py + + # Collect tworoom + python scripts/data/collect_tworooms.py + + # cube is auto-downloaded from HuggingFace when you run create_data.py + +HDF5 files are written to $STABLEWM_HOME (default: ~/.stable_worldmodel/). + +CONVERTING TO LANCEDB +--------------------- # Convert all datasets to a local LanceDB store python create_data.py --dataset all --lance-uri ./lewm_lance @@ -35,6 +49,7 @@ import numpy as np import pyarrow as pa from PIL import Image +from stable_worldmodel.data.utils import get_cache_dir, load_dataset from tqdm import tqdm @@ -43,23 +58,28 @@ # --------------------------------------------------------------------------- DATASETS = { + # swm_name is passed to stable_worldmodel.data.load_dataset(). + # For reacher/pusht/tworoom these are local dataset names (collected via + # the stable-worldmodel data-collection scripts and stored in STABLEWM_HOME). + # For cube, "ogbench/cube_single_expert" is a HuggingFace repo id — load_dataset() + # downloads and caches it automatically on first run. "reacher": { - "hdf5_file": "reacher.hdf5", + "swm_name": "reacher", "table_name": "lewm_reacher", "columns": ["pixels", "action", "observation"], }, "cube": { - "hdf5_file": "cube_single_expert.hdf5", + "swm_name": "ogbench/cube_single_expert", "table_name": "lewm_cube", "columns": ["pixels", "action", "observation"], }, "pusht": { - "hdf5_file": "pusht_expert_train.hdf5", + "swm_name": "pusht_expert_train", "table_name": "lewm_pusht", "columns": ["pixels", "action", "proprio", "state"], }, "tworoom": { - "hdf5_file": "tworoom.hdf5", + "swm_name": "tworoom", "table_name": "lewm_tworoom", "columns": ["pixels", "action", "proprio"], }, @@ -196,26 +216,26 @@ def _make_batch( def convert_dataset( dataset_name: str, - hdf5_dir: str, lance_uri: str, overwrite: bool = False, connect_kwargs: dict | None = None, ): - cfg = DATASETS[dataset_name] - hdf5_path = os.path.join(hdf5_dir, cfg["hdf5_file"]) + cfg = DATASETS[dataset_name] + swm_name = cfg["swm_name"] table_name = cfg["table_name"] columns = cfg["columns"] connect_kwargs = connect_kwargs or {} - if not os.path.exists(hdf5_path): - raise FileNotFoundError( - f"HDF5 not found: {hdf5_path}\n" - "Download with: " - f"python -c \"import stable_worldmodel as swm; swm.data.download('{dataset_name}')\"" - ) - + # Resolve the HDF5 path via stable_worldmodel. + # load_dataset() handles both local names (looked up in $STABLEWM_HOME) and + # HuggingFace repo ids (e.g. "ogbench/cube_single_expert") — it downloads and + # caches the archive automatically for the latter. print(f"\n{'=' * 60}") - print(f"Dataset : {dataset_name}") + print(f"Dataset : {dataset_name} (swm_name={swm_name!r})") + print(f" Resolving HDF5 path via stable_worldmodel...") + ds = load_dataset(swm_name) + hdf5_path = ds.h5_path + print(f"HDF5 : {hdf5_path}") print(f"Lance : {lance_uri} (table={table_name})") print(f"{'=' * 60}") @@ -466,11 +486,6 @@ def main(): default="all", help="Dataset to convert (default: all)", ) - parser.add_argument( - "--hdf5-dir", - default=os.path.expanduser("~/.stable-wm"), - help="Directory containing .hdf5 files (default: ~/.stable-wm)", - ) parser.add_argument( "--lance-uri", default="./lewm_lance", @@ -515,7 +530,6 @@ def main(): for ds_name in datasets_to_run: convert_dataset( dataset_name=ds_name, - hdf5_dir=args.hdf5_dir, lance_uri=args.lance_uri, overwrite=args.overwrite, ) From 96f867e1111a227f78b0ba9208376391eac9fdd5 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 16:05:50 +0530 Subject: [PATCH 03/29] update --- examples/leWorldModel/README.md | 28 +++++----------------------- examples/leWorldModel/create_data.py | 24 +++++++++++------------- 2 files changed, 16 insertions(+), 36 deletions(-) diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index 68ab389..d93cc1d 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -55,30 +55,12 @@ For the DataLoader, `num_workers=6` works well with a local LanceDB store. With ## Reproducing the paper -### Step 1 — Collect / download datasets +### Step 1 — No dataset setup needed -Three datasets are collected locally using the stable-worldmodel expert scripts. -The cube dataset (`ogbench/cube_single_expert`) is downloaded automatically from -HuggingFace during conversion — no manual step needed for that one. - -```bash -# Clone stable-worldmodel for the data collection scripts -git clone https://github.com/galilai-group/stable-worldmodel /tmp/stable-worldmodel -cd /tmp/stable-worldmodel && pip install -e . - -# Collect PushT expert demonstrations (~1000 episodes) -python scripts/data/collect_weak_pusht.py - -# Collect DMControl Reacher -python scripts/data/collect_dmc.py - -# Collect TwoRoom -python scripts/data/collect_tworooms.py - -# cube is auto-downloaded from HuggingFace in the next step -``` - -HDF5 files are saved to `$STABLEWM_HOME` (default: `~/.stable_worldmodel/`). +All four datasets are published on HuggingFace at +https://huggingface.co/collections/quentinll/lewm. +`create_data.py` downloads and caches each one automatically on first run +via `stable_worldmodel.data.load_dataset()` — just run Step 2. ### Step 2 — Convert datasets to LanceDB diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index d8b4424..a2f9e13 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -58,28 +58,27 @@ # --------------------------------------------------------------------------- DATASETS = { - # swm_name is passed to stable_worldmodel.data.load_dataset(). - # For reacher/pusht/tworoom these are local dataset names (collected via - # the stable-worldmodel data-collection scripts and stored in STABLEWM_HOME). - # For cube, "ogbench/cube_single_expert" is a HuggingFace repo id — load_dataset() - # downloads and caches it automatically on first run. + # swm_name is a HuggingFace repo id (owner/repo). + # stable_worldmodel.data.load_dataset() downloads and caches the archive + # automatically on first run — no manual download needed. + # HF collection: https://huggingface.co/collections/quentinll/lewm "reacher": { - "swm_name": "reacher", + "swm_name": "quentinll/lewm-reacher", "table_name": "lewm_reacher", "columns": ["pixels", "action", "observation"], }, "cube": { - "swm_name": "ogbench/cube_single_expert", + "swm_name": "quentinll/lewm-cube", "table_name": "lewm_cube", "columns": ["pixels", "action", "observation"], }, "pusht": { - "swm_name": "pusht_expert_train", + "swm_name": "quentinll/lewm-pusht", "table_name": "lewm_pusht", "columns": ["pixels", "action", "proprio", "state"], }, "tworoom": { - "swm_name": "tworoom", + "swm_name": "quentinll/lewm-tworooms", "table_name": "lewm_tworoom", "columns": ["pixels", "action", "proprio"], }, @@ -226,10 +225,9 @@ def convert_dataset( columns = cfg["columns"] connect_kwargs = connect_kwargs or {} - # Resolve the HDF5 path via stable_worldmodel. - # load_dataset() handles both local names (looked up in $STABLEWM_HOME) and - # HuggingFace repo ids (e.g. "ogbench/cube_single_expert") — it downloads and - # caches the archive automatically for the latter. + # load_dataset() resolves HuggingFace repo ids: downloads the .tar.zst archive, + # extracts it, caches the .h5 file under $STABLEWM_HOME, and returns an + # HDF5Dataset with the resolved .h5_path. Nothing to do manually. print(f"\n{'=' * 60}") print(f"Dataset : {dataset_name} (swm_name={swm_name!r})") print(f" Resolving HDF5 path via stable_worldmodel...") From 69c0b51220471d379725353718c0e9dfc66905c7 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 16:10:46 +0530 Subject: [PATCH 04/29] update --- examples/leWorldModel/create_data.py | 64 ++++++++++++++++++++++---- examples/leWorldModel/requirements.txt | 1 + 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index a2f9e13..0c44cee 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -48,8 +48,8 @@ import lancedb import numpy as np import pyarrow as pa +from huggingface_hub import snapshot_download from PIL import Image -from stable_worldmodel.data.utils import get_cache_dir, load_dataset from tqdm import tqdm @@ -209,6 +209,59 @@ def _make_batch( return pa.RecordBatch.from_arrays(arrays, schema=schema) +# --------------------------------------------------------------------------- +# HuggingFace dataset resolution +# --------------------------------------------------------------------------- + +_CACHE_DIR = os.path.expanduser(os.environ.get("STABLEWM_HOME", "~/.stable_worldmodel")) + + +def _ensure_hdf5(hf_repo: str) -> str: + """ + Return the local path to the .h5 file for *hf_repo*. + + Downloads the repo snapshot from HuggingFace Hub on first call and caches + it under $STABLEWM_HOME (default: ~/.stable_worldmodel/). + The repo is expected to contain exactly one .h5 / .hdf5 file. + """ + import glob + cache_dir = os.path.join(_CACHE_DIR, "datasets", hf_repo.replace("/", "--")) + # Return cached file if already present + for ext in ("*.h5", "*.hdf5"): + matches = glob.glob(os.path.join(cache_dir, ext)) + if matches: + return matches[0] + + print(f" Downloading {hf_repo} from HuggingFace Hub...") + snapshot_download( + repo_id=hf_repo, + repo_type="dataset", + local_dir=cache_dir, + ignore_patterns=["*.tar.zst"], # skip raw archives if .h5 is published directly + ) + + # After download, check for .h5 first; fall back to extracting .tar.zst + for ext in ("*.h5", "*.hdf5"): + matches = glob.glob(os.path.join(cache_dir, ext)) + if matches: + return matches[0] + + # Extract any .tar.zst archive present + import subprocess + archives = glob.glob(os.path.join(cache_dir, "*.tar.zst")) + assert archives, f"No .h5 or .tar.zst found in {cache_dir} after downloading {hf_repo}" + for archive in archives: + subprocess.run( + ["tar", "--use-compress-program=unzstd", "-xf", archive, "-C", cache_dir], + check=True, + ) + os.remove(archive) + + matches = glob.glob(os.path.join(cache_dir, "*.h5")) + glob.glob(os.path.join(cache_dir, "*.hdf5")) + assert matches, f"No .h5 file found in {cache_dir} after extracting archives" + return matches[0] + + # --------------------------------------------------------------------------- # Core conversion # --------------------------------------------------------------------------- @@ -225,15 +278,10 @@ def convert_dataset( columns = cfg["columns"] connect_kwargs = connect_kwargs or {} - # load_dataset() resolves HuggingFace repo ids: downloads the .tar.zst archive, - # extracts it, caches the .h5 file under $STABLEWM_HOME, and returns an - # HDF5Dataset with the resolved .h5_path. Nothing to do manually. print(f"\n{'=' * 60}") - print(f"Dataset : {dataset_name} (swm_name={swm_name!r})") - print(f" Resolving HDF5 path via stable_worldmodel...") - ds = load_dataset(swm_name) - hdf5_path = ds.h5_path + print(f"Dataset : {dataset_name} (hf_repo={swm_name!r})") + hdf5_path = _ensure_hdf5(swm_name) print(f"HDF5 : {hdf5_path}") print(f"Lance : {lance_uri} (table={table_name})") print(f"{'=' * 60}") diff --git a/examples/leWorldModel/requirements.txt b/examples/leWorldModel/requirements.txt index 59257f2..4715490 100644 --- a/examples/leWorldModel/requirements.txt +++ b/examples/leWorldModel/requirements.txt @@ -2,6 +2,7 @@ # Core lancedb>=0.20.0 +huggingface_hub>=0.23.0 pyarrow>=16.0.0 torch>=2.2.0 torchvision>=0.17.0 From 4350dc5ecc6ed0ca0496d09d658fb52b192ac272 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 16:13:20 +0530 Subject: [PATCH 05/29] update --- examples/leWorldModel/create_data.py | 52 +++++++++++++++------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index 0c44cee..a73eb22 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -48,7 +48,7 @@ import lancedb import numpy as np import pyarrow as pa -from huggingface_hub import snapshot_download +from huggingface_hub import hf_hub_download, list_repo_files from PIL import Image from tqdm import tqdm @@ -220,45 +220,49 @@ def _ensure_hdf5(hf_repo: str) -> str: """ Return the local path to the .h5 file for *hf_repo*. - Downloads the repo snapshot from HuggingFace Hub on first call and caches - it under $STABLEWM_HOME (default: ~/.stable_worldmodel/). - The repo is expected to contain exactly one .h5 / .hdf5 file. + On first call: finds the .tar.zst (or .h5) file in the HF repo, downloads + just that file, extracts it if needed, and caches the result under + $STABLEWM_HOME/datasets/--/. + Subsequent calls return the cached path immediately. """ import glob + import subprocess + cache_dir = os.path.join(_CACHE_DIR, "datasets", hf_repo.replace("/", "--")) - # Return cached file if already present - for ext in ("*.h5", "*.hdf5"): - matches = glob.glob(os.path.join(cache_dir, ext)) + os.makedirs(cache_dir, exist_ok=True) + + # Return cached .h5 if already extracted + for pattern in ("*.h5", "*.hdf5"): + matches = glob.glob(os.path.join(cache_dir, pattern)) if matches: return matches[0] - print(f" Downloading {hf_repo} from HuggingFace Hub...") - snapshot_download( + # Find the data file in the repo (expect one .tar.zst or .h5) + repo_files = list(list_repo_files(hf_repo, repo_type="dataset")) + data_file = next( + (f for f in repo_files if f.endswith(".tar.zst") or f.endswith(".h5") or f.endswith(".hdf5")), + None, + ) + assert data_file, f"No .h5 or .tar.zst file found in HF repo {hf_repo}. Files: {repo_files}" + + print(f" Downloading {hf_repo}/{data_file}...") + local_file = hf_hub_download( repo_id=hf_repo, + filename=data_file, repo_type="dataset", local_dir=cache_dir, - ignore_patterns=["*.tar.zst"], # skip raw archives if .h5 is published directly ) - # After download, check for .h5 first; fall back to extracting .tar.zst - for ext in ("*.h5", "*.hdf5"): - matches = glob.glob(os.path.join(cache_dir, ext)) - if matches: - return matches[0] - - # Extract any .tar.zst archive present - import subprocess - archives = glob.glob(os.path.join(cache_dir, "*.tar.zst")) - assert archives, f"No .h5 or .tar.zst found in {cache_dir} after downloading {hf_repo}" - for archive in archives: + if local_file.endswith(".tar.zst"): + print(f" Extracting {os.path.basename(local_file)}...") subprocess.run( - ["tar", "--use-compress-program=unzstd", "-xf", archive, "-C", cache_dir], + ["tar", "--use-compress-program=unzstd", "-xf", local_file, "-C", cache_dir], check=True, ) - os.remove(archive) + os.remove(local_file) matches = glob.glob(os.path.join(cache_dir, "*.h5")) + glob.glob(os.path.join(cache_dir, "*.hdf5")) - assert matches, f"No .h5 file found in {cache_dir} after extracting archives" + assert matches, f"No .h5 file found in {cache_dir} after extracting {data_file}" return matches[0] From 72c3fecf7bb2bd42925c6fc4d470dd83bf2220b7 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 20:06:39 +0530 Subject: [PATCH 06/29] update --- examples/leWorldModel/create_data.py | 58 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index a73eb22..5731809 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -42,6 +42,7 @@ import io import os from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor import h5py import hdf5plugin # noqa: F401 — registers HDF5 decompression filters (Blosc, Zstd, etc.) @@ -85,7 +86,8 @@ } JPEG_QUALITY = 95 # 95 → ~13× smaller than raw uint8, negligible quality loss -BATCH_ROWS = 1000 # rows per RecordBatch yielded to LanceDB +BATCH_ROWS = 4096 # rows read from HDF5 and written to LanceDB per chunk +JPEG_WORKERS = 8 # parallel threads for JPEG encoding within each batch # --------------------------------------------------------------------------- @@ -151,38 +153,38 @@ def _record_batch_reader( """ Return a pa.RecordBatchReader that streams the HDF5 file in BATCH_ROWS chunks. - Using a RecordBatchReader rather than repeated table.add() calls lets - LanceDB write all data in a single pass through the file without accumulating - large in-memory lists, and avoids creating many small Lance fragments. + Reads BATCH_ROWS rows at a time from HDF5 (one slice per column, not one + row at a time) and encodes the pixel batch in parallel with a thread pool. + This is the dominant speedup over the naive row-by-row approach. """ total = len(episode_arr) non_pixel_cols = [c for c in columns if c != "pixels"] + def _encode_frame(frame: np.ndarray) -> bytes: + return _to_jpeg_bytes(frame) + def _generate() -> Iterator[pa.RecordBatch]: - # Buffers — reset every BATCH_ROWS rows - ep_buf: list[int] = [] - st_buf: list[int] = [] - px_buf: list[bytes] = [] - ph_buf: list[int] = [] - pw_buf: list[int] = [] - col_bufs: dict[str, list[list[float]]] = {c: [] for c in non_pixel_cols} - - for idx in tqdm(range(total), desc=" Converting", unit="step"): - ep_buf.append(int(episode_arr[idx])) - st_buf.append(int(step_arr[idx])) - px_buf.append(_to_jpeg_bytes(np.array(f["pixels"][idx]))) - ph_buf.append(h) - pw_buf.append(w) - for col in non_pixel_cols: - col_bufs[col].append(np.array(f[col][idx], dtype=np.float32).flatten().tolist()) - - if len(ep_buf) == BATCH_ROWS: - yield _make_batch(ep_buf, st_buf, px_buf, ph_buf, pw_buf, col_bufs, schema) - ep_buf, st_buf, px_buf, ph_buf, pw_buf = [], [], [], [], [] - col_bufs = {c: [] for c in non_pixel_cols} - - if ep_buf: - yield _make_batch(ep_buf, st_buf, px_buf, ph_buf, pw_buf, col_bufs, schema) + with ThreadPoolExecutor(max_workers=JPEG_WORKERS) as pool: + for start in tqdm(range(0, total, BATCH_ROWS), desc=" Converting", unit="batch"): + end = min(start + BATCH_ROWS, total) + sl = slice(start, end) + + # One HDF5 slice read per column — orders of magnitude fewer I/O ops + pixels_raw = f["pixels"][sl] # (B, H, W, C) + col_data = {c: np.array(f[c][sl], dtype=np.float32) for c in non_pixel_cols} + + # Encode all frames in the batch concurrently + px_buf = list(pool.map(_encode_frame, pixels_raw)) + + yield _make_batch( + episode_arr[sl].tolist(), + step_arr[sl].tolist(), + px_buf, + [h] * len(px_buf), + [w] * len(px_buf), + {c: col_data[c].reshape(len(px_buf), -1).tolist() for c in non_pixel_cols}, + schema, + ) return pa.RecordBatchReader.from_batches(schema, _generate()) From ee75038a732c1e0ec1becf3f73f05b785b1ef3d9 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 20:48:46 +0530 Subject: [PATCH 07/29] update --- examples/leWorldModel/README.md | 45 ++- examples/leWorldModel/config/lewm_pusht.yaml | 75 ++-- examples/leWorldModel/create_data.py | 12 +- .../leWorldModel/lewm_lance/dataloaders.py | 21 +- examples/leWorldModel/lewm_lance/dataset.py | 190 +++++----- examples/leWorldModel/train.py | 356 ++++++++++-------- 6 files changed, 394 insertions(+), 305 deletions(-) diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index d93cc1d..3c5bf20 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -140,10 +140,47 @@ python train.py --config config/lewm_pusht.yaml \ python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm ``` -### Step 5 — Post-training analysis - -After training, add LeWM encoder embeddings to the table for world-model-specific analysis. -These reflect what the trained model considers semantically similar — different from DINOv2. +### Step 5 — Post-training analysis with LeWM embeddings + +#### What we're doing and why + +After training, we run the trained LeWM encoder over every frame in the dataset +and store the resulting CLS-token vectors as a new `emb_lewm` column in the LanceDB +table. This lets us query the table using the world model's own learned similarity — +not pixel similarity (DINOv2/CLIP) but *dynamics similarity*: two frames are close +in `emb_lewm` space if the world model predicts them as leading to similar futures. + +#### What this reveals + +The le-wm paper shows that the encoder's latent space encodes meaningful physical +structure: it separates behaviorally distinct states and can be probed for quantities +like object position, velocity, and task progress. Adding `emb_lewm` to LanceDB +lets us do this interactively: + +- **Nearest-neighbor retrieval**: given a query frame, find the K training frames + the model considers most similar — sanity-checks whether the world model's + similarity makes physical sense. +- **DINOv2 vs LeWM comparison**: the same two frames may be far apart in DINOv2 + space (different appearance) but close in LeWM space (same task phase), or vice + versa. Comparing the two embedding columns directly shows what the model has + learned to *ignore* (irrelevant visual details) and what it *attends to* + (task-relevant structure). +- **Clustering / UMAP**: exporting `emb_lewm` → UMAP reveals whether the latent + space organises into interpretable clusters (e.g. "reaching", "grasping", + "releasing" in a manipulation task). +- **Failure diagnosis**: episodes where val loss is high can be retrieved by + their `emb_lewm` vectors and inspected — often revealing a sub-behaviour the + model hasn't learned well. + +#### Did the paper authors do this? + +The le-wm paper validates the latent space through *probing* — training small linear +heads on top of frozen encoder embeddings to predict physical quantities (object +position, velocity). This is the standard JEPA evaluation protocol. Those probing +scripts are not in the public repo, but the technique is identical to what we do +here: freeze the trained encoder, run it over the dataset, store the vectors, then +analyse them. Storing them in LanceDB rather than a separate file means the analysis +is a single ANN query away. ```bash python create_data.py \ diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml index b5640e3..7b2ac6e 100644 --- a/examples/leWorldModel/config/lewm_pusht.yaml +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -1,42 +1,57 @@ # leWorldModel training config — PushT dataset -# Mirrors le-wm/config/train/lewm.yaml + data/pusht.yaml +# Mirrors le-wm/config/train/lewm.yaml + config/data/pusht.yaml +# LanceDB-specific keys added under 'data:' (not present in original le-wm). +# All other sections are a strict superset of the original le-wm config. -model: - encoder_name: "vit_tiny_patch14_224" - image_size: 224 - embed_dim: 192 - history_size: 3 - num_preds: 1 - predictor_depth: 6 - predictor_heads: 16 - predictor_mlp_dim: 2048 - proj_hidden: 2048 - # action_dim is inferred automatically from the first training batch +seed: 42 +img_size: 224 -data: - lance_uri: "./lewm_lance" # override with s3://bucket/prefix for cloud - table_name: "lewm_pusht" - columns: ["pixels", "action", "proprio", "state"] +trainer: + max_epochs: 100 + precision: "bf16-mixed" # PyTorch Lightning format (le-wm uses "bf16" via Hydra) + gradient_clip_val: 1.0 + log_every_n_steps: 50 + save_every_n_epochs: 10 + checkpoint_dir: "./checkpoints" + +loader: + batch_size: 128 num_workers: 6 prefetch_factor: 3 - val_fraction: 0.1 - seed: 42 -training: - batch_size: 128 - max_epochs: 100 +optimizer: lr: 5.0e-5 weight_decay: 1.0e-3 - warmup_epochs: 10 - gradient_clip: 1.0 - save_every_n_epochs: 10 - checkpoint_dir: "./checkpoints" - precision: "bf16-mixed" # bf16-mixed | 16-mixed | 32 - log_every_n_steps: 50 + warmup_epochs: 10 # added: controls LinearWarmup → Cosine schedule + +wm: + history_size: 3 + num_preds: 1 + embed_dim: 192 + encoder_name: "vit_tiny_patch14_224" # added: timm model identifier + proj_hidden: 2048 # added: projector MLP hidden dim + +predictor: + depth: 6 + heads: 16 + mlp_dim: 2048 + dim_head: 64 + dropout: 0.1 + emb_dropout: 0.0 loss: - sigreg_weight: 0.09 - sigreg_knots: 17 - sigreg_num_proj: 1024 + sigreg: + weight: 0.09 + kwargs: + knots: 17 + num_proj: 1024 + +# ── LanceDB-specific additions (not present in original le-wm) ──────────────── +data: + lance_uri: "./lewm_lance" # override with s3://bucket/prefix for cloud + table_name: "lewm_pusht" + columns: ["pixels", "action", "proprio", "state"] + frameskip: 5 # matches le-wm paper default + val_fraction: 0.1 # episode-level hold-out fraction wandb_project: "lewm-lancedb" diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index 5731809..2c072b1 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -239,10 +239,12 @@ def _ensure_hdf5(hf_repo: str) -> str: if matches: return matches[0] - # Find the data file in the repo (expect one .tar.zst or .h5) + # Find the data file in the repo (expect one .tar.zst, .h5.zst, or bare .h5/.hdf5) repo_files = list(list_repo_files(hf_repo, repo_type="dataset")) data_file = next( - (f for f in repo_files if f.endswith(".tar.zst") or f.endswith(".h5") or f.endswith(".hdf5")), + (f for f in repo_files + if f.endswith(".tar.zst") or f.endswith(".h5.zst") + or f.endswith(".h5") or f.endswith(".hdf5")), None, ) assert data_file, f"No .h5 or .tar.zst file found in HF repo {hf_repo}. Files: {repo_files}" @@ -262,6 +264,12 @@ def _ensure_hdf5(hf_repo: str) -> str: check=True, ) os.remove(local_file) + elif local_file.endswith(".h5.zst"): + # Bare zstd-compressed HDF5 (no tar wrapper) — decompress with zstd + out_path = local_file[:-4] # strip .zst → .h5 + print(f" Decompressing {os.path.basename(local_file)} → {os.path.basename(out_path)}...") + subprocess.run(["zstd", "-d", local_file, "-o", out_path], check=True) + os.remove(local_file) matches = glob.glob(os.path.join(cache_dir, "*.h5")) + glob.glob(os.path.join(cache_dir, "*.hdf5")) assert matches, f"No .h5 file found in {cache_dir} after extracting {data_file}" diff --git a/examples/leWorldModel/lewm_lance/dataloaders.py b/examples/leWorldModel/lewm_lance/dataloaders.py index 26c7a38..519d368 100644 --- a/examples/leWorldModel/lewm_lance/dataloaders.py +++ b/examples/leWorldModel/lewm_lance/dataloaders.py @@ -36,6 +36,7 @@ def make_lewm_lance_loader( columns: list[str], batch_size: int, num_steps: int = 4, + frameskip: int = 5, img_size: int = 224, num_workers: int = 6, prefetch_factor: int = 3, @@ -45,24 +46,15 @@ def make_lewm_lance_loader( """ Build a DataLoader over a LanceDB leWorldModel table. - Args: - uri: LanceDB URI. - table_name: Table name (as created by create_data.py). - columns: Columns to load, e.g. ["pixels", "action", "proprio"]. - batch_size: Training batch size B. - num_steps: Window length T = history_size + num_preds. - img_size: Resize target (square) applied before ImageNet normalization. - num_workers: DataLoader worker processes. - prefetch_factor: Batches queued per worker. - normalizers: {col: (mean, std)} from compute_normalizers(); if None, - no normalization is applied to non-pixel columns. - **connect_kwargs: Forwarded to lancedb.connect() (api_key, host_override, …). + frameskip=5 matches the le-wm paper default. With T=4 and frameskip=5, + each window spans 20 raw rows; action is reshaped to (T, 5×action_dim). """ dataset = LeWMLanceDataset( uri=uri, table_name=table_name, columns=columns, num_steps=num_steps, + frameskip=frameskip, img_size=img_size, normalizers=normalizers, **connect_kwargs, @@ -76,6 +68,7 @@ def make_train_val_loaders( columns: list[str], batch_size: int, num_steps: int = 4, + frameskip: int = 5, img_size: int = 224, num_workers: int = 6, prefetch_factor: int = 3, @@ -111,8 +104,8 @@ def make_train_val_loaders( normalizers = compute_normalizers(uri, table_name, columns, **connect_kwargs) # Build full datasets then restrict _window_starts by episode membership - train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, img_size, normalizers, **connect_kwargs) - val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, img_size, normalizers, **connect_kwargs) + train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, normalizers, **connect_kwargs) + val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, normalizers, **connect_kwargs) train_ep_mask = np.isin(train_ds._ep[train_ds._window_starts], list(train_episodes)) val_ep_mask = np.isin(val_ds._ep[val_ds._window_starts], list(val_episodes)) diff --git a/examples/leWorldModel/lewm_lance/dataset.py b/examples/leWorldModel/lewm_lance/dataset.py index 4b38879..9377b9c 100644 --- a/examples/leWorldModel/lewm_lance/dataset.py +++ b/examples/leWorldModel/lewm_lance/dataset.py @@ -1,22 +1,30 @@ """ LanceDB-backed PyTorch Dataset for leWorldModel temporal sequences. -leWorldModel trains on windows of T consecutive frames from the same episode: - T = history_size (3) + num_preds (1) = 4 by default +leWorldModel trains on windows of T frames with a configurable frameskip (default 5, +matching the original le-wm paper): + T = history_size (3) + num_preds (1) = 4 frames + span = T × frameskip = 20 raw rows per window + +Frameskip mirrors the original HDF5Dataset behaviour: + - Pixels: sampled at stride frameskip → (T, C, H, W) + - Actions: ALL span rows kept, reshaped to (T, frameskip × action_dim) + This matches le-wm's effective_act_dim = frameskip × action_dim. + - Other columns (proprio, state, observation): sampled at stride frameskip → (T, D) Each dataset item is a dict of tensors: - "pixels" : (T, C, H, W) float32 ImageNet-normalized - "action" : (T, A) float32 z-score normalized, NaN→0 - "proprio" : (T, P) float32 z-score normalized [if present] + "pixels" : (T, C, H, W) float32 ImageNet-normalized + "action" : (T, frameskip×A) float32 z-score normalized, NaN→0 + "proprio" : (T, P) float32 [if present] ... Design: - - One LanceDB row = one timestep. + - One LanceDB row = one raw timestep. - Window index is precomputed at __init__ so __getitems__ only does I/O. - Permutation object (Rust state) is zeroed before pickling; each worker lazily reopens its own connection inside _ensure_open(). - - __getitems__ batches the full (B*T) row fetch in a single Permutation call, - then splits into per-sample dicts — same pattern as ViT dataloaders.py. + - __getitems__ batches the full (B × span) row fetch in a single Permutation + call, then splits into per-sample dicts. """ import io @@ -54,30 +62,20 @@ def compute_normalizers( **connect_kwargs, ) -> dict[str, tuple[np.ndarray, np.ndarray]]: """ - Compute per-column (mean, std) arrays for z-score normalization. - - Reads only the requested non-pixel columns using LanceDB's column projection - so no pixel data is loaded. For large datasets this streams in batches via - the Arrow scanner rather than materializing the full table at once. - - Returns: - {col: (mean_array, std_array)} — each array has shape (D,). + Compute per-column (mean, std) for z-score normalization. + Only reads non-pixel columns; no pixel data loaded. """ - db = lancedb.connect(uri, **connect_kwargs) + db = lancedb.connect(uri, **connect_kwargs) tbl = db.open_table(table_name) non_pixel = [c for c in columns if c != "pixels"] if not non_pixel: return {} - - # Use column-projected Arrow read — loads only the requested columns. - # episode_idx + step_idx columns are tiny; the vector columns are float32 - # lists already in Arrow format, so this is as efficient as possible. arrow = tbl.to_arrow(columns=non_pixel) normalizers = {} for col in non_pixel: - data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) valid = ~np.isnan(data).any(axis=1) - data = data[valid] + data = data[valid] normalizers[col] = (data.mean(axis=0), data.std(axis=0)) return normalizers @@ -89,11 +87,15 @@ class LeWMLanceDataset(torch.utils.data.Dataset): Args: uri: LanceDB URI (local path or s3://…). table_name: Name of the table created by create_data.py. - columns: List of column names to return, e.g. ["pixels","action","proprio"]. + columns: List of column names to return. num_steps: Window length T (= history_size + num_preds). - img_size: Target image size after resize (square). - normalizers: Output of compute_normalizers(); used to z-score non-pixel columns. - **connect_kwargs: Passed to lancedb.connect() (api_key, host_override, region, …). + frameskip: Stride between sampled frames. Matches le-wm default of 5. + With frameskip=5 and T=4, each window spans 20 raw rows. + action is kept at full resolution and reshaped to + (T, frameskip × action_dim); all other columns are strided. + img_size: Target image size after resize. + normalizers: {col: (mean, std)} from compute_normalizers(). + **connect_kwargs: Passed to lancedb.connect(). """ def __init__( @@ -102,66 +104,54 @@ def __init__( table_name: str, columns: list[str], num_steps: int = 4, + frameskip: int = 5, img_size: int = 224, normalizers: dict | None = None, **connect_kwargs, ): - self.uri = uri - self.table_name = table_name - self.columns = columns - self.num_steps = num_steps - self.img_size = img_size + self.uri = uri + self.table_name = table_name + self.columns = columns + self.num_steps = num_steps + self.frameskip = frameskip + self.img_size = img_size self.normalizers = normalizers or {} self.connect_kwargs = connect_kwargs + self._span = num_steps * frameskip # raw rows per window - # Rust Permutation — zeroed before pickling, rebuilt per-worker - self._perm: Permutation | None = None + self._perm: Permutation | None = None self._transform: transforms.Compose | None = None - # ------------------------------------------------------------------ - # Eagerly load the episode/step index to precompute valid windows. - # These are only two int32 columns — ~8 bytes/row regardless of - # dataset size, so loading them fully is fine. - # ------------------------------------------------------------------ - db = lancedb.connect(uri, **connect_kwargs) + # Load only the two index columns to precompute valid windows. + db = lancedb.connect(uri, **connect_kwargs) tbl = db.open_table(table_name) - idx_arrow = tbl.to_arrow(columns=["episode_idx", "step_idx"]) - self._ep = idx_arrow["episode_idx"].to_numpy().astype(np.int32) - self._step = idx_arrow["step_idx"].to_numpy().astype(np.int32) + idx = tbl.to_arrow(columns=["episode_idx", "step_idx"]) + self._ep = idx["episode_idx"].to_numpy().astype(np.int32) + self._step = idx["step_idx"].to_numpy().astype(np.int32) self._n_rows = len(self._ep) - # Precompute valid window start rows. - # A window starting at row i is valid iff rows i..i+T-1 are all in - # the same episode and have consecutive step indices. - T = num_steps - N = self._n_rows - T + 1 + # A window starting at row i is valid iff all span rows are in the same + # episode with consecutive step indices. + span = self._span + N = self._n_rows - span + 1 valid = np.ones(N, dtype=bool) - for offset in range(1, T): - same_ep = self._ep[offset : N + offset] == self._ep[:N] - consec = self._step[offset : N + offset] == self._step[:N] + offset - valid &= same_ep & consec - - # _window_starts[i] = absolute row index for the i-th valid window + for offset in range(1, span): + valid &= (self._ep[offset : N + offset] == self._ep[:N]) + valid &= (self._step[offset : N + offset] == self._step[:N] + offset) self._window_starts = np.where(valid)[0].astype(np.int64) - # ---------------------------------------------------------------------- # - # PyTorch Dataset protocol - # ---------------------------------------------------------------------- # - def __len__(self) -> int: return len(self._window_starts) def __getstate__(self) -> dict: - """Zero out Rust state before the object is pickled for a worker process.""" state = self.__dict__.copy() - state["_perm"] = None + state["_perm"] = None state["_transform"] = None return state def _ensure_open(self): - """Lazily open DB connection + Permutation once per worker process.""" if self._perm is None: - db = lancedb.connect(self.uri, **self.connect_kwargs) + db = lancedb.connect(self.uri, **self.connect_kwargs) tbl = db.open_table(self.table_name) fetch_cols = ["pixels"] + [c for c in self.columns if c != "pixels"] self._perm = ( @@ -171,61 +161,71 @@ def _ensure_open(self): ) self._transform = _build_img_transform(self.img_size) - # ---------------------------------------------------------------------- # - # Internal: convert a RecordBatch of T rows into a sample dict - # ---------------------------------------------------------------------- # - def _rows_to_sample(self, batch: pa.RecordBatch) -> dict[str, torch.Tensor]: - T = self.num_steps - assert len(batch) == T + """ + Convert a RecordBatch of `span` raw rows into one training sample. + + Pixels and non-action columns: take every frameskip-th row → T frames. + Action: keep all span rows, reshape to (T, frameskip × action_dim). + """ + T = self.num_steps + frameskip = self.frameskip + assert len(batch) == self._span - # Decode JPEG pixels → (T, C, H, W) + # Pixels: stride by frameskip → T frames jpeg_list = batch["pixels"].to_pylist() - frames = torch.stack([_jpeg_to_tensor(b, self._transform) for b in jpeg_list]) + frames = torch.stack( + [_jpeg_to_tensor(jpeg_list[t * frameskip], self._transform) for t in range(T)] + ) sample: dict[str, torch.Tensor] = {"pixels": frames} for col in self.columns: if col == "pixels": continue - data = np.array([batch[col][t].as_py() for t in range(T)], dtype=np.float32) - if col in self.normalizers: - mean, std = self.normalizers[col] - data = (data - mean) / (std + 1e-8) - data = np.nan_to_num(data, nan=0.0) + + if col == "action": + # All span steps → (span, action_dim) → (T, frameskip × action_dim) + data = np.array( + [batch[col][i].as_py() for i in range(self._span)], + dtype=np.float32, + ) + data = np.nan_to_num(data, nan=0.0) + data = data.reshape(T, -1) # (T, frameskip × action_dim) + else: + # Stride by frameskip → T steps + data = np.array( + [batch[col][t * frameskip].as_py() for t in range(T)], + dtype=np.float32, + ) + if col in self.normalizers: + mean, std = self.normalizers[col] + data = (data - mean) / (std + 1e-8) + data = np.nan_to_num(data, nan=0.0) + sample[col] = torch.from_numpy(data) return sample - # ---------------------------------------------------------------------- # - # Single-item access (used when num_workers=0) - # ---------------------------------------------------------------------- # - def __getitem__(self, window_idx: int) -> dict[str, torch.Tensor]: self._ensure_open() start = int(self._window_starts[window_idx]) - rows = list(range(start, start + self.num_steps)) + rows = list(range(start, start + self._span)) batch = self._perm.__getitems__(rows) return self._rows_to_sample(batch) - # ---------------------------------------------------------------------- # - # Batch access — called by DataLoader with num_workers > 0. - # Fetches all B*T rows in ONE Permutation call instead of B calls. - # ---------------------------------------------------------------------- # - def __getitems__(self, window_indices: list[int]) -> list[dict[str, torch.Tensor]]: + """Fetch all B × span rows in one Permutation call.""" self._ensure_open() - T = self.num_steps - starts = self._window_starts[window_indices] # (B,) + span = self._span + starts = self._window_starts[window_indices] - # Build flat list: [w0_t0, w0_t1, …, w1_t0, w1_t1, …] all_rows: list[int] = [] for s in starts: - all_rows.extend(range(int(s), int(s) + T)) + all_rows.extend(range(int(s), int(s) + span)) - big_batch: pa.RecordBatch = self._perm.__getitems__(all_rows) # (B*T, cols) + big_batch: pa.RecordBatch = self._perm.__getitems__(all_rows) - samples = [] - for b in range(len(window_indices)): - row_slice = big_batch.slice(b * T, T) - samples.append(self._rows_to_sample(row_slice)) - return samples + return [ + self._rows_to_sample(big_batch.slice(b * span, span)) + for b in range(len(window_indices)) + ] diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index b76b50c..048f5be 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -16,9 +16,6 @@ --aws-access-key-id AKIA... \\ --aws-secret-access-key ... - # S3-backed store, credentials via environment (AWS_ACCESS_KEY_ID etc.) - python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm - # Override table and columns without editing the config python train.py --config config/lewm_pusht.yaml \\ --table-name lewm_reacher \\ @@ -31,17 +28,71 @@ sys.path.insert(0, os.path.dirname(__file__)) -import pytorch_lightning as pl import timm import torch +import torch.nn as nn import yaml +import pytorch_lightning as pl +from pathlib import Path +from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers import WandbLogger from jepa import JEPA from lewm_lance import make_train_val_loaders from module import ARPredictor, Embedder, MLP, SIGReg -from stable_worldmodel.optim import LinearWarmupCosineAnnealingLR -from utils import ModelObjectCallBack + + +# --------------------------------------------------------------------------- +# Encoder wrapper +# +# JEPA.encode() calls: +# output = self.encoder(pixels, interpolate_pos_encoding=True) +# pixels_emb = output.last_hidden_state[:, 0] ← CLS token +# +# timm ViTs return (B, D) when num_classes=0, not a structured object. +# We wrap forward_features() — which returns (B, N, D) with CLS at [:, 0] — +# to satisfy the HuggingFace-style interface JEPA expects. +# --------------------------------------------------------------------------- + +class _TimmViTOutput: + __slots__ = ("last_hidden_state",) + def __init__(self, last_hidden_state): + self.last_hidden_state = last_hidden_state + + +class TimmViT(nn.Module): + def __init__(self, model_name: str, img_size: int): + super().__init__() + self._model = timm.create_model( + model_name, pretrained=False, img_size=img_size, num_classes=0 + ) + self.embed_dim = self._model.embed_dim + + def forward(self, x, interpolate_pos_encoding=False): + # forward_features → (B, N, D) where N[0] is the CLS token + return _TimmViTOutput(self._model.forward_features(x)) + + +# --------------------------------------------------------------------------- +# Checkpoint callback (inlined to avoid stable_pretraining dependency) +# --------------------------------------------------------------------------- + +class ModelObjectCallBack(Callback): + """Save the raw model object (torch.save) at the end of every epoch_interval epochs.""" + + def __init__(self, dirpath, filename="model_object", epoch_interval: int = 1): + super().__init__() + self.dirpath = Path(dirpath) + self.filename = filename + self.epoch_interval = epoch_interval + + def on_train_epoch_end(self, trainer, pl_module): + epoch = trainer.current_epoch + 1 + if not trainer.is_global_zero: + return + if epoch % self.epoch_interval == 0 or epoch == trainer.max_epochs: + path = self.dirpath / f"{self.filename}_epoch_{epoch}_object.ckpt" + torch.save(pl_module.model, path) # --------------------------------------------------------------------------- @@ -50,12 +101,18 @@ class LeWMLightning(pl.LightningModule): """ - PyTorch Lightning wrapper around the JEPA world model. - - Accepts the batch dict produced by make_train_val_loaders: - "pixels" : (B, T, C, H, W) float32 - "action" : (B, T, A) float32 - ...additional columns... + PyTorch Lightning wrapper around the LeWM JEPA world model. + + Training loss (mirrors lejepa_forward in the original le-wm/train.py): + 1. Encode the pixel+action sequence into embedding space. + 2. Predict next-state embeddings autoregressively from the context window. + 3. Prediction loss: MSE between predicted and actual next embeddings. + 4. SIGReg loss: Sketch Isotropic Gaussian Regularizer — keeps the latent + distribution well-shaped, preventing collapse and mode-dropping. + + Note: JEPA.criterion() is the MPC planning cost (comparing predicted rollouts + to a goal embedding during evaluation). It is NOT the training loss and is + never called here. """ def __init__(self, model: JEPA, sigreg: SIGReg, cfg: dict): @@ -66,9 +123,28 @@ def __init__(self, model: JEPA, sigreg: SIGReg, cfg: dict): self.save_hyperparameters(ignore=["model", "sigreg"]) def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: - loss_pred = self.model.criterion(batch)["loss"] - loss_reg = self.sigreg(self.model.last_embeddings) - loss = loss_pred + self.cfg["sigreg_weight"] * loss_reg + ctx_len = self.cfg["history_size"] + n_preds = self.cfg["num_preds"] + + # NaN occurs at sequence boundaries (padding); zero it out + batch["action"] = torch.nan_to_num(batch["action"], 0.0) + + # Encode pixels and actions → (B, T, embed_dim) each + output = self.model.encode(batch) + emb = output["emb"] # (B, T, D) + act_emb = output["act_emb"] # (B, T, D) + + # Predict next states from the context window (first history_size frames) + ctx_emb = emb[:, :ctx_len] + ctx_act = act_emb[:, :ctx_len] + tgt_emb = emb[:, n_preds:] # ground-truth targets (shifted by n_preds) + pred_emb = self.model.predict(ctx_emb, ctx_act) + + # SIGReg expects (T, B, D) — transpose time and batch dims + loss_pred = (pred_emb - tgt_emb).pow(2).mean() + loss_reg = self.sigreg(emb.transpose(0, 1)) + loss = loss_pred + self.cfg["sigreg_weight"] * loss_reg + self.log(f"{stage}/loss_pred", loss_pred, on_step=(stage == "train"), on_epoch=True, prog_bar=True) self.log(f"{stage}/loss_reg", loss_reg, on_step=(stage == "train"), on_epoch=True) self.log(f"{stage}/loss", loss, on_step=(stage == "train"), on_epoch=True, prog_bar=True) @@ -86,10 +162,17 @@ def configure_optimizers(self): lr=self.cfg["lr"], weight_decay=self.cfg["weight_decay"], ) - sched = LinearWarmupCosineAnnealingLR( - opt, - warmup_epochs=self.cfg["warmup_epochs"], - max_epochs=self.cfg["max_epochs"], + warmup = self.cfg["warmup_epochs"] + total = self.cfg["max_epochs"] + # Linear warmup → cosine decay (matches le-wm's LinearWarmupCosineAnnealingLR) + warmup_sched = torch.optim.lr_scheduler.LinearLR( + opt, start_factor=1e-4, end_factor=1.0, total_iters=warmup + ) + cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR( + opt, T_max=max(total - warmup, 1), eta_min=0 + ) + sched = torch.optim.lr_scheduler.SequentialLR( + opt, schedulers=[warmup_sched, cosine_sched], milestones=[warmup] ) return [opt], [{"scheduler": sched, "interval": "epoch"}] @@ -98,32 +181,42 @@ def configure_optimizers(self): # Model construction # --------------------------------------------------------------------------- -def build_model(cfg: dict) -> tuple[JEPA, SIGReg]: - m = cfg["model"] +def build_model(cfg: dict, effective_act_dim: int) -> tuple[JEPA, SIGReg]: + """ + Build the LeWM JEPA model from config. - encoder = timm.create_model( - m["encoder_name"], - pretrained=False, - img_size=m["image_size"], - num_classes=0, - ) + effective_act_dim = frameskip × raw_action_dim. + With the default frameskip=5, five consecutive raw action steps are stacked + into one frame-level action vector, so the Embedder sees a larger input. + """ + wm = cfg["wm"] + pred = cfg["predictor"] + encoder = TimmViT(wm["encoder_name"], cfg["img_size"]) + hidden_dim = encoder.embed_dim # ViT-tiny: 192 + + # ARPredictor: input_dim and hidden_dim can differ. + # Here we keep them equal (both embed_dim), matching le-wm defaults. predictor = ARPredictor( - embed_dim=m["embed_dim"], - depth=m["predictor_depth"], - num_heads=m["predictor_heads"], - mlp_dim=m["predictor_mlp_dim"], - max_seq_len=m["history_size"] + m["num_preds"], + num_frames=wm["history_size"], + input_dim=wm["embed_dim"], + hidden_dim=hidden_dim, + depth=pred["depth"], + heads=pred["heads"], + mlp_dim=pred["mlp_dim"], + dim_head=pred["dim_head"], + dropout=pred["dropout"], + emb_dropout=pred["emb_dropout"], ) action_encoder = Embedder( - in_dim=m["action_dim"], - out_dim=m["embed_dim"], + input_dim=effective_act_dim, + emb_dim=wm["embed_dim"], ) - encoder_dim = encoder.embed_dim - projector = MLP(encoder_dim, m["proj_hidden"], m["embed_dim"]) - pred_proj = MLP(m["embed_dim"], m["proj_hidden"], m["embed_dim"]) + # MLP(input_dim, hidden_dim, output_dim, ...) — norm_fn=BatchNorm1d matches le-wm defaults + projector = MLP(hidden_dim, wm["proj_hidden"], wm["embed_dim"], norm_fn=torch.nn.BatchNorm1d) + pred_proj = MLP(wm["embed_dim"], wm["proj_hidden"], wm["embed_dim"], norm_fn=torch.nn.BatchNorm1d) model = JEPA( encoder=encoder, @@ -131,14 +224,13 @@ def build_model(cfg: dict) -> tuple[JEPA, SIGReg]: action_encoder=action_encoder, projector=projector, pred_proj=pred_proj, - history_size=m["history_size"], - num_preds=m["num_preds"], ) + # SIGReg: knots and num_proj only — no embed_dim + sigreg_cfg = cfg["loss"]["sigreg"] sigreg = SIGReg( - embed_dim=m["embed_dim"], - knots=cfg["loss"]["sigreg_knots"], - num_proj=cfg["loss"]["sigreg_num_proj"], + knots=sigreg_cfg["kwargs"]["knots"], + num_proj=sigreg_cfg["kwargs"]["num_proj"], ) return model, sigreg @@ -149,45 +241,22 @@ def build_model(cfg: dict) -> tuple[JEPA, SIGReg]: # --------------------------------------------------------------------------- def build_storage_options(args: argparse.Namespace) -> dict: - """ - Build the storage_options dict for lancedb.connect() from CLI args, - falling back to standard AWS environment variables. - - LanceDB passes storage_options directly to the Rust object_store library, - which accepts these keys for S3: - aws_access_key_id, aws_secret_access_key, aws_session_token, - region, endpoint_url, aws_virtual_hosted_style_request - - Environment variable fallbacks follow the standard AWS SDK convention: - AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, - AWS_DEFAULT_REGION, AWS_ENDPOINT_URL - - Returns an empty dict for local URIs (no storage_options needed). - """ uri = args.lance_uri if not (uri.startswith("s3://") or uri.startswith("gs://") or uri.startswith("az://")): return {} - opts: dict[str, str] = {} - - access_key = args.aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") - secret_key = args.aws_secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") - session_token = args.aws_session_token or os.environ.get("AWS_SESSION_TOKEN") - region = args.aws_region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") - endpoint = args.s3_endpoint or os.environ.get("AWS_ENDPOINT_URL") - - if access_key: - opts["aws_access_key_id"] = access_key - if secret_key: - opts["aws_secret_access_key"] = secret_key - if session_token: - opts["aws_session_token"] = session_token - if region: - opts["region"] = region + access_key = args.aws_access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") + secret_key = args.aws_secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") + session_token = args.aws_session_token or os.environ.get("AWS_SESSION_TOKEN") + region = args.aws_region or os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") + endpoint = args.s3_endpoint or os.environ.get("AWS_ENDPOINT_URL") + if access_key: opts["aws_access_key_id"] = access_key + if secret_key: opts["aws_secret_access_key"] = secret_key + if session_token: opts["aws_session_token"] = session_token + if region: opts["region"] = region if endpoint: opts["endpoint_url"] = endpoint opts["aws_virtual_hosted_style_request"] = "false" - return opts @@ -200,101 +269,62 @@ def _build_parser() -> argparse.ArgumentParser: description="Train leWorldModel with LanceDB data backend", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - - parser.add_argument( - "--config", - default="config/lewm_pusht.yaml", - help="Path to YAML training config", - ) - parser.add_argument( - "--lance-uri", - default=None, - help="LanceDB URI. Defaults to data.lance_uri in config. " - "Use s3://bucket/prefix for cloud storage.", - ) - parser.add_argument( - "--table-name", - default=None, - help="LanceDB table name. Defaults to data.table_name in config.", - ) - parser.add_argument( - "--columns", - nargs="+", - default=None, - help="Columns to load. Defaults to data.columns in config.", - ) - parser.add_argument( - "--run-name", - default=None, - help="WandB run name. Defaults to -steps.", - ) - parser.add_argument( - "--no-wandb", - action="store_true", - help="Disable WandB logging.", - ) - - s3 = parser.add_argument_group( - "S3 storage options", - "Credentials for S3-backed LanceDB tables. All args fall back to " - "standard AWS environment variables (AWS_ACCESS_KEY_ID, etc.).", - ) + parser.add_argument("--config", default="config/lewm_pusht.yaml") + parser.add_argument("--lance-uri", default=None) + parser.add_argument("--table-name", default=None) + parser.add_argument("--columns", nargs="+", default=None) + parser.add_argument("--run-name", default=None) + parser.add_argument("--no-wandb", action="store_true") + s3 = parser.add_argument_group("S3 storage options") s3.add_argument("--aws-access-key-id", default=None, metavar="KEY") s3.add_argument("--aws-secret-access-key", default=None, metavar="SECRET") - s3.add_argument("--aws-session-token", default=None, metavar="TOKEN", - help="Temporary session token (STS / IAM role assumed credentials).") - s3.add_argument("--aws-region", default=None, metavar="REGION", - help="AWS region, e.g. us-east-1.") - s3.add_argument("--s3-endpoint", default=None, metavar="URL", - help="Custom S3-compatible endpoint (MinIO, R2, etc.).") - + s3.add_argument("--aws-session-token", default=None, metavar="TOKEN") + s3.add_argument("--aws-region", default=None, metavar="REGION") + s3.add_argument("--s3-endpoint", default=None, metavar="URL") return parser def main(): parser = _build_parser() - args = parser.parse_args() + args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) - data_cfg = cfg["data"] - model_cfg = cfg["model"] - train_cfg = cfg["training"] - loss_cfg = cfg.get("loss", {}) + data_cfg = cfg["data"] + loader_cfg = cfg["loader"] + trainer_cfg = cfg["trainer"] + opt_cfg = cfg["optimizer"] + wm_cfg = cfg["wm"] - # Resolve: CLI arg > config value > hardcoded fallback lance_uri = args.lance_uri or data_cfg.get("lance_uri", "./lewm_lance") table_name = args.table_name or data_cfg.get("table_name") columns = args.columns or data_cfg["columns"] - num_steps = model_cfg["history_size"] + model_cfg["num_preds"] + num_steps = wm_cfg["history_size"] + wm_cfg["num_preds"] + frameskip = data_cfg.get("frameskip", 1) if table_name is None: - parser.error("table_name is required: set data.table_name in config or pass --table-name") + parser.error("table_name required: set data.table_name in config or pass --table-name") storage_options = build_storage_options(args) - if storage_options: - print(f" S3 storage: region={storage_options.get('region', 'env')}" - + (f" endpoint={storage_options['endpoint_url']}" if "endpoint_url" in storage_options else "")) + connect_kwargs = {"storage_options": storage_options} if storage_options else {} # ------------------------------------------------------------------ # # Data # ------------------------------------------------------------------ # - # storage_options is only passed for cloud URIs; empty dict for local paths - connect_kwargs = {"storage_options": storage_options} if storage_options else {} - print(f"Building DataLoaders ({lance_uri} / {table_name})...") train_loader, val_loader = make_train_val_loaders( uri=lance_uri, table_name=table_name, columns=columns, - batch_size=train_cfg["batch_size"], + batch_size=loader_cfg["batch_size"], num_steps=num_steps, - img_size=model_cfg["image_size"], - num_workers=data_cfg["num_workers"], - prefetch_factor=data_cfg["prefetch_factor"], + frameskip=frameskip, + img_size=cfg["img_size"], + num_workers=loader_cfg["num_workers"], + prefetch_factor=loader_cfg["prefetch_factor"], val_fraction=data_cfg["val_fraction"], - seed=data_cfg["seed"], + seed=cfg["seed"], **connect_kwargs, ) print(f" Train batches: {len(train_loader):,} | Val batches: {len(val_loader):,}") @@ -302,24 +332,35 @@ def main(): # ------------------------------------------------------------------ # # Model # ------------------------------------------------------------------ # + # Infer effective_act_dim from the first batch (action shape: B, T, eff_dim) + sample_batch = next(iter(train_loader)) + effective_act_dim = sample_batch["action"].shape[-1] + print(f" effective_act_dim={effective_act_dim} " + f"(frameskip={frameskip} × raw_action_dim={effective_act_dim // max(frameskip,1)})") + print("Building model...") - model_cfg["action_dim"] = _infer_action_dim(train_loader) - model, sigreg = build_model(cfg) + model, sigreg = build_model(cfg, effective_act_dim) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" Trainable parameters: {n_params / 1e6:.1f}M") - lightning_model = LeWMLightning( - model=model, - sigreg=sigreg, - cfg={**train_cfg, **loss_cfg}, - ) + # Flat dict passed into LeWMLightning for optimizer/scheduler and loss + lightning_cfg = { + "lr": opt_cfg["lr"], + "weight_decay": opt_cfg["weight_decay"], + "warmup_epochs": opt_cfg["warmup_epochs"], + "max_epochs": trainer_cfg["max_epochs"], + "sigreg_weight": cfg["loss"]["sigreg"]["weight"], + "history_size": wm_cfg["history_size"], + "num_preds": wm_cfg["num_preds"], + } + + lightning_model = LeWMLightning(model=model, sigreg=sigreg, cfg=lightning_cfg) # ------------------------------------------------------------------ # # Logging & callbacks # ------------------------------------------------------------------ # run_name = args.run_name or f"{table_name}-{num_steps}T" - - logger = None + logger = None if not args.no_wandb: logger = WandbLogger( project=cfg.get("wandb_project", "lewm-lancedb"), @@ -327,14 +368,13 @@ def main(): config={**cfg, "lance_uri": lance_uri, "table": table_name}, ) - ckpt_dir = train_cfg["checkpoint_dir"] + ckpt_dir = trainer_cfg["checkpoint_dir"] os.makedirs(ckpt_dir, exist_ok=True) - callbacks = [ ModelObjectCallBack( dirpath=ckpt_dir, filename=f"{table_name}_lewm", - epoch_interval=train_cfg["save_every_n_epochs"], + epoch_interval=trainer_cfg["save_every_n_epochs"], ) ] @@ -342,12 +382,13 @@ def main(): # Trainer # ------------------------------------------------------------------ # trainer = pl.Trainer( - max_epochs=train_cfg["max_epochs"], - precision=train_cfg["precision"], - gradient_clip_val=train_cfg["gradient_clip"], + max_epochs=trainer_cfg["max_epochs"], + precision=trainer_cfg["precision"], + gradient_clip_val=trainer_cfg["gradient_clip_val"], logger=logger, callbacks=callbacks, - log_every_n_steps=train_cfg["log_every_n_steps"], + log_every_n_steps=trainer_cfg["log_every_n_steps"], + num_sanity_val_steps=1, enable_progress_bar=True, ) @@ -356,10 +397,5 @@ def main(): print("Training complete.") -def _infer_action_dim(loader: torch.utils.data.DataLoader) -> int: - batch = next(iter(loader)) - return batch["action"].shape[-1] - - if __name__ == "__main__": main() From c9f0b942648c2b58079157f9e11ab25aa81bb0b5 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 21:19:55 +0530 Subject: [PATCH 08/29] support cpu sanity --- examples/leWorldModel/train.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 048f5be..7a366f4 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -273,8 +273,12 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--lance-uri", default=None) parser.add_argument("--table-name", default=None) parser.add_argument("--columns", nargs="+", default=None) - parser.add_argument("--run-name", default=None) - parser.add_argument("--no-wandb", action="store_true") + parser.add_argument("--run-name", default=None) + parser.add_argument("--no-wandb", action="store_true") + parser.add_argument("--fast-dev-run", action="store_true", + help="Run 1 train+val batch then exit (smoke test)") + parser.add_argument("--precision", default=None, + help="Override trainer.precision (e.g. 32, 16-mixed, bf16-mixed)") s3 = parser.add_argument_group("S3 storage options") s3.add_argument("--aws-access-key-id", default=None, metavar="KEY") s3.add_argument("--aws-secret-access-key", default=None, metavar="SECRET") @@ -381,14 +385,16 @@ def main(): # ------------------------------------------------------------------ # # Trainer # ------------------------------------------------------------------ # + precision = args.precision or trainer_cfg["precision"] trainer = pl.Trainer( max_epochs=trainer_cfg["max_epochs"], - precision=trainer_cfg["precision"], + precision=precision, gradient_clip_val=trainer_cfg["gradient_clip_val"], logger=logger, callbacks=callbacks, log_every_n_steps=trainer_cfg["log_every_n_steps"], num_sanity_val_steps=1, + fast_dev_run=args.fast_dev_run, enable_progress_bar=True, ) From 344b2713aa79f1857b3729c7f8fc41edb6762959 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 21:32:28 +0530 Subject: [PATCH 09/29] update --- examples/leWorldModel/requirements.txt | 2 ++ examples/leWorldModel/train.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/leWorldModel/requirements.txt b/examples/leWorldModel/requirements.txt index 4715490..cb71e99 100644 --- a/examples/leWorldModel/requirements.txt +++ b/examples/leWorldModel/requirements.txt @@ -12,6 +12,8 @@ pytorch-lightning>=2.2.0 timm>=1.0.0 pyyaml>=6.0 stable-worldmodel +# le-wm is not a Python package — clone it next to train.py: +# git clone https://github.com/lucas-maes/le-wm # Data h5py>=3.10.0 diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 7a366f4..cf0de55 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -26,7 +26,19 @@ import os import sys -sys.path.insert(0, os.path.dirname(__file__)) +_HERE = os.path.dirname(__file__) +sys.path.insert(0, _HERE) + +# jepa.py and module.py live at the root of the le-wm repo (not a Python package). +# Clone it next to this file: git clone https://github.com/lucas-maes/le-wm +_LEWM_DIR = os.path.join(_HERE, "le-wm") +if not os.path.isdir(_LEWM_DIR): + raise RuntimeError( + f"le-wm repo not found at {_LEWM_DIR}.\n" + "Run: git clone https://github.com/lucas-maes/le-wm " + f"{_LEWM_DIR}" + ) +sys.path.insert(0, _LEWM_DIR) import timm import torch From 33e35e70fc530b5659af9c12c5fd793503793158 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 21:39:43 +0530 Subject: [PATCH 10/29] update --- examples/leWorldModel/README.md | 2 +- examples/leWorldModel/bench.py | 2 +- examples/leWorldModel/{lewm_lance => lewm_loader}/__init__.py | 0 .../leWorldModel/{lewm_lance => lewm_loader}/dataloaders.py | 0 examples/leWorldModel/{lewm_lance => lewm_loader}/dataset.py | 0 examples/leWorldModel/train.py | 2 +- 6 files changed, 3 insertions(+), 3 deletions(-) rename examples/leWorldModel/{lewm_lance => lewm_loader}/__init__.py (100%) rename examples/leWorldModel/{lewm_lance => lewm_loader}/dataloaders.py (100%) rename examples/leWorldModel/{lewm_lance => lewm_loader}/dataset.py (100%) diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index 3c5bf20..e90667b 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -9,7 +9,7 @@ examples/leWorldModel/ ├── eda_analysis.py # EDA, quality scan, splits, vector search ├── config/ │ └── lewm_pusht.yaml # Example config (copy and edit per dataset) -└── lewm_lance/ +└── lewm_loader/ ├── dataset.py # LeWMLanceDataset — temporal window sampler └── dataloaders.py # make_train_val_loaders() factory ``` diff --git a/examples/leWorldModel/bench.py b/examples/leWorldModel/bench.py index 23e9cb5..38a4a20 100644 --- a/examples/leWorldModel/bench.py +++ b/examples/leWorldModel/bench.py @@ -42,7 +42,7 @@ import sys sys.path.insert(0, os.path.dirname(__file__)) -from lewm_lance import make_lewm_lance_loader +from lewm_loader import make_lewm_lance_loader # --------------------------------------------------------------------------- diff --git a/examples/leWorldModel/lewm_lance/__init__.py b/examples/leWorldModel/lewm_loader/__init__.py similarity index 100% rename from examples/leWorldModel/lewm_lance/__init__.py rename to examples/leWorldModel/lewm_loader/__init__.py diff --git a/examples/leWorldModel/lewm_lance/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py similarity index 100% rename from examples/leWorldModel/lewm_lance/dataloaders.py rename to examples/leWorldModel/lewm_loader/dataloaders.py diff --git a/examples/leWorldModel/lewm_lance/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py similarity index 100% rename from examples/leWorldModel/lewm_lance/dataset.py rename to examples/leWorldModel/lewm_loader/dataset.py diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index cf0de55..009118b 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -50,7 +50,7 @@ from pytorch_lightning.loggers import WandbLogger from jepa import JEPA -from lewm_lance import make_train_val_loaders +from lewm_loader import make_train_val_loaders from module import ARPredictor, Embedder, MLP, SIGReg From 34550cbf5aac6ac05e8b83c498c1a47dea16747b Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 21:44:18 +0530 Subject: [PATCH 11/29] resolve path first --- examples/leWorldModel/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 009118b..f22274d 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -252,8 +252,7 @@ def build_model(cfg: dict, effective_act_dim: int) -> tuple[JEPA, SIGReg]: # S3 storage options # --------------------------------------------------------------------------- -def build_storage_options(args: argparse.Namespace) -> dict: - uri = args.lance_uri +def build_storage_options(args: argparse.Namespace, uri: str) -> dict: if not (uri.startswith("s3://") or uri.startswith("gs://") or uri.startswith("az://")): return {} opts: dict[str, str] = {} @@ -322,7 +321,7 @@ def main(): if table_name is None: parser.error("table_name required: set data.table_name in config or pass --table-name") - storage_options = build_storage_options(args) + storage_options = build_storage_options(args, lance_uri) connect_kwargs = {"storage_options": storage_options} if storage_options else {} # ------------------------------------------------------------------ # From 01518bdbb60c7a5ad311b5a54985ceb0dc59f6f9 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 22:01:46 +0530 Subject: [PATCH 12/29] update --- examples/leWorldModel/eda_analysis.py | 38 ++++++++++--------- .../leWorldModel/lewm_loader/dataloaders.py | 2 +- examples/leWorldModel/lewm_loader/dataset.py | 12 +++--- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/examples/leWorldModel/eda_analysis.py b/examples/leWorldModel/eda_analysis.py index df14511..8acd992 100644 --- a/examples/leWorldModel/eda_analysis.py +++ b/examples/leWorldModel/eda_analysis.py @@ -43,16 +43,17 @@ def dataset_statistics(tbl: lancedb.table.Table): schema = tbl.schema total_rows = len(tbl) - # Read only the two index columns — negligible memory - idx_arrow = tbl.to_arrow(columns=["episode_idx"]) - n_episodes = len(pc.unique(idx_arrow["episode_idx"])) + # Read only the episode index column — negligible memory + ds = tbl.to_lance() + ep_idx_table = ds.to_table(columns=["episode_idx"]) + n_episodes = len(pc.unique(ep_idx_table["episode_idx"])) print(f"\nTotal timesteps : {total_rows:,}") print(f"Total episodes : {n_episodes:,}") print(f"Avg steps/ep : {total_rows / n_episodes:.1f}") print(f"\nSchema:\n{schema}\n") - # Per-column stats for fixed-size list columns (skip pixels and embeddings) + # Per-column stats — load one column at a time to bound peak memory list_cols = [ f.name for f in schema if (pa.types.is_list(f.type) or pa.types.is_fixed_size_list(f.type)) @@ -62,9 +63,9 @@ def dataset_statistics(tbl: lancedb.table.Table): if not list_cols: return - arrow = tbl.to_arrow(columns=list_cols) for col in list_cols: - data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + col_table = ds.to_table(columns=[col]) + data = np.array(col_table[col].to_pylist(), dtype=np.float32) valid = ~np.isnan(data).any(axis=1) data = data[valid] print(f" {col:<14} dim={data.shape[1]:3d} | " @@ -73,7 +74,7 @@ def dataset_statistics(tbl: lancedb.table.Table): f"NaN rows={(~valid).sum()}") # Episode length distribution - ep_arr = idx_arrow["episode_idx"].to_numpy() + ep_arr = ep_idx_table["episode_idx"].to_numpy() _, counts = np.unique(ep_arr, return_counts=True) print(f"\nEpisode length min={counts.min()} max={counts.max()} " f"median={int(np.median(counts))} std={counts.std():.1f}") @@ -104,7 +105,7 @@ def create_splits( print("2. EPISODE-LEVEL SPLITS") print("=" * 60) - ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + ep_arr = tbl.to_lance().to_table(columns=["episode_idx"])["episode_idx"].to_numpy() all_eps = np.unique(ep_arr) rng = np.random.default_rng(seed) rng.shuffle(all_eps) @@ -124,7 +125,7 @@ def create_splits( print(f" {name:<6}: {len(eps):5,} episodes {ep_mask.sum():8,} timesteps") print("\n To use a split in training, filter with:") - print(" tbl.to_arrow(columns=[...], filter='episode_idx IN (0,1,...)')") + print(" tbl.to_lance().to_table(columns=[...], filter='episode_idx IN (0,1,...)')") return splits @@ -142,7 +143,7 @@ def temporal_coherence_check(tbl: lancedb.table.Table): print("3. TEMPORAL COHERENCE CHECK") print("=" * 60) - idx = tbl.to_arrow(columns=["episode_idx", "step_idx"]) + idx = tbl.to_lance().to_table(columns=["episode_idx", "step_idx"]) ep = idx["episode_idx"].to_numpy() step = idx["step_idx"].to_numpy() @@ -266,9 +267,9 @@ def episode_retrieval_demo( print(f" [SKIP] '{emb_col}' column not found.") return - arrow = tbl.to_arrow(columns=["episode_idx", emb_col]) + arrow = tbl.to_lance().to_table(columns=["episode_idx", emb_col]) ep_arr = arrow["episode_idx"].to_numpy() - emb_arr = np.stack([row.as_py() for row in arrow[emb_col]], axis=0).astype(np.float32) + emb_arr = np.array(arrow[emb_col].to_pylist(), dtype=np.float32) unique_eps = np.unique(ep_arr) ep_means = {ep: emb_arr[ep_arr == ep].mean(axis=0) for ep in unique_eps} @@ -308,9 +309,9 @@ def action_entropy_analysis(tbl: lancedb.table.Table, top_k: int = 5): print(" [SKIP] No action column.") return - arrow = tbl.to_arrow(columns=["episode_idx", "action"]) + arrow = tbl.to_lance().to_table(columns=["episode_idx", "action"]) ep_arr = arrow["episode_idx"].to_numpy() - act_arr = np.stack([row.as_py() for row in arrow["action"]], axis=0).astype(np.float32) + act_arr = np.array(arrow["action"].to_pylist(), dtype=np.float32) unique_eps = np.unique(ep_arr) entropies = { @@ -353,16 +354,17 @@ def data_quality_scan(tbl: lancedb.table.Table): and f.name not in ("pixels",) and not f.name.startswith("emb_") ] + ds = tbl.to_lance() for col in list_cols: - arrow = tbl.to_arrow(columns=[col]) - data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + col_table = ds.to_table(columns=[col]) + data = np.array(col_table[col].to_pylist(), dtype=np.float32) n_nan = int(np.isnan(data).any(axis=1).sum()) pct = 100 * n_nan / total flag = "[WARN]" if pct > 5 else "[OK] " print(f" {flag} {col:<14} NaN rows: {n_nan:,} ({pct:.1f}%)") # Degenerate episode check - ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + ep_arr = ds.to_table(columns=["episode_idx"])["episode_idx"].to_numpy() _, counts = np.unique(ep_arr, return_counts=True) short_eps = int((counts < 4).sum()) if short_eps == 0: @@ -429,7 +431,7 @@ def print_lancedb_vs_hdf5(): window into what the world model has learned to focus on. 4. EPISODE FILTERING WITHOUT ARRAY MANIPULATION - tbl.to_arrow(filter="episode_idx IN (...)") returns only the matching + tbl.to_lance().to_table(filter="episode_idx IN (...)") returns only the matching rows, columnar-compressed, as Arrow. With HDF5 you load the full array and mask in Python. diff --git a/examples/leWorldModel/lewm_loader/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py index 519d368..8cf9e69 100644 --- a/examples/leWorldModel/lewm_loader/dataloaders.py +++ b/examples/leWorldModel/lewm_loader/dataloaders.py @@ -89,7 +89,7 @@ def make_train_val_loaders( tbl = db.open_table(table_name) # Only reads one int32 column — negligible memory even at millions of rows - ep_arr = tbl.to_arrow(columns=["episode_idx"])["episode_idx"].to_numpy() + ep_arr = tbl.to_lance().to_table(columns=["episode_idx"])["episode_idx"].to_numpy() all_episodes = np.unique(ep_arr) rng = np.random.default_rng(seed) diff --git a/examples/leWorldModel/lewm_loader/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py index 9377b9c..78f4b22 100644 --- a/examples/leWorldModel/lewm_loader/dataset.py +++ b/examples/leWorldModel/lewm_loader/dataset.py @@ -63,17 +63,18 @@ def compute_normalizers( ) -> dict[str, tuple[np.ndarray, np.ndarray]]: """ Compute per-column (mean, std) for z-score normalization. - Only reads non-pixel columns; no pixel data loaded. + Only reads non-pixel columns one at a time; pixel/blob data is never loaded. """ db = lancedb.connect(uri, **connect_kwargs) tbl = db.open_table(table_name) non_pixel = [c for c in columns if c != "pixels"] if not non_pixel: return {} - arrow = tbl.to_arrow(columns=non_pixel) + ds = tbl.to_lance() normalizers = {} for col in non_pixel: - data = np.stack([row.as_py() for row in arrow[col]], axis=0).astype(np.float32) + col_table = ds.to_table(columns=[col]) + data = np.array(col_table[col].to_pylist(), dtype=np.float32) valid = ~np.isnan(data).any(axis=1) data = data[valid] normalizers[col] = (data.mean(axis=0), data.std(axis=0)) @@ -122,10 +123,11 @@ def __init__( self._perm: Permutation | None = None self._transform: transforms.Compose | None = None - # Load only the two index columns to precompute valid windows. + # Load only the two int32 index columns to precompute valid windows. + # Pixels and all other data columns are never touched here. db = lancedb.connect(uri, **connect_kwargs) tbl = db.open_table(table_name) - idx = tbl.to_arrow(columns=["episode_idx", "step_idx"]) + idx = tbl.to_lance().to_table(columns=["episode_idx", "step_idx"]) self._ep = idx["episode_idx"].to_numpy().astype(np.int32) self._step = idx["step_idx"].to_numpy().astype(np.int32) self._n_rows = len(self._ep) From 9bb41d0f98eb83d8ee035b8c3a2e730987dd8a2f Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 22:10:51 +0530 Subject: [PATCH 13/29] update --- examples/leWorldModel/lewm_loader/dataset.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/examples/leWorldModel/lewm_loader/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py index 78f4b22..c380598 100644 --- a/examples/leWorldModel/lewm_loader/dataset.py +++ b/examples/leWorldModel/lewm_loader/dataset.py @@ -214,20 +214,3 @@ def __getitem__(self, window_idx: int) -> dict[str, torch.Tensor]: rows = list(range(start, start + self._span)) batch = self._perm.__getitems__(rows) return self._rows_to_sample(batch) - - def __getitems__(self, window_indices: list[int]) -> list[dict[str, torch.Tensor]]: - """Fetch all B × span rows in one Permutation call.""" - self._ensure_open() - span = self._span - starts = self._window_starts[window_indices] - - all_rows: list[int] = [] - for s in starts: - all_rows.extend(range(int(s), int(s) + span)) - - big_batch: pa.RecordBatch = self._perm.__getitems__(all_rows) - - return [ - self._rows_to_sample(big_batch.slice(b * span, span)) - for b in range(len(window_indices)) - ] From 443b56a0bb367b84cf2831f2b6a52bb44705ee24 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 1 Apr 2026 22:16:50 +0530 Subject: [PATCH 14/29] update --- examples/leWorldModel/config/lewm_pusht.yaml | 2 +- examples/leWorldModel/requirements.txt | 2 +- examples/leWorldModel/train.py | 50 ++++++++------------ 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml index 7b2ac6e..80bba91 100644 --- a/examples/leWorldModel/config/lewm_pusht.yaml +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -28,7 +28,7 @@ wm: history_size: 3 num_preds: 1 embed_dim: 192 - encoder_name: "vit_tiny_patch14_224" # added: timm model identifier + patch_size: 14 # ViT patch size — with img_size=224 gives 16×16=256 patches + CLS proj_hidden: 2048 # added: projector MLP hidden dim predictor: diff --git a/examples/leWorldModel/requirements.txt b/examples/leWorldModel/requirements.txt index cb71e99..8a4e6a9 100644 --- a/examples/leWorldModel/requirements.txt +++ b/examples/leWorldModel/requirements.txt @@ -9,7 +9,7 @@ torchvision>=0.17.0 # Training pytorch-lightning>=2.2.0 -timm>=1.0.0 +transformers>=4.40.0 pyyaml>=6.0 stable-worldmodel # le-wm is not a Python package — clone it next to train.py: diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index f22274d..1fa236d 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -40,9 +40,9 @@ ) sys.path.insert(0, _LEWM_DIR) -import timm import torch import torch.nn as nn +from transformers import ViTConfig, ViTModel import yaml import pytorch_lightning as pl from pathlib import Path @@ -55,35 +55,15 @@ # --------------------------------------------------------------------------- -# Encoder wrapper +# Encoder # -# JEPA.encode() calls: -# output = self.encoder(pixels, interpolate_pos_encoding=True) -# pixels_emb = output.last_hidden_state[:, 0] ← CLS token -# -# timm ViTs return (B, D) when num_classes=0, not a structured object. -# We wrap forward_features() — which returns (B, N, D) with CLS at [:, 0] — -# to satisfy the HuggingFace-style interface JEPA expects. +# le-wm uses spt.backbone.utils.vit_hf() from stable_pretraining, which +# creates a HuggingFace ViTModel. JEPA.encode() expects exactly that +# interface: output.last_hidden_state[:, 0] → CLS token. +# We build the same model directly with transformers.ViTConfig/ViTModel, +# matching the ViT-tiny spec from the le-wm paper. # --------------------------------------------------------------------------- -class _TimmViTOutput: - __slots__ = ("last_hidden_state",) - def __init__(self, last_hidden_state): - self.last_hidden_state = last_hidden_state - - -class TimmViT(nn.Module): - def __init__(self, model_name: str, img_size: int): - super().__init__() - self._model = timm.create_model( - model_name, pretrained=False, img_size=img_size, num_classes=0 - ) - self.embed_dim = self._model.embed_dim - - def forward(self, x, interpolate_pos_encoding=False): - # forward_features → (B, N, D) where N[0] is the CLS token - return _TimmViTOutput(self._model.forward_features(x)) - # --------------------------------------------------------------------------- # Checkpoint callback (inlined to avoid stable_pretraining dependency) @@ -204,8 +184,20 @@ def build_model(cfg: dict, effective_act_dim: int) -> tuple[JEPA, SIGReg]: wm = cfg["wm"] pred = cfg["predictor"] - encoder = TimmViT(wm["encoder_name"], cfg["img_size"]) - hidden_dim = encoder.embed_dim # ViT-tiny: 192 + # Build ViT-tiny with HuggingFace transformers — same as le-wm's vit_hf("tiny", ...) + vit_cfg = ViTConfig( + hidden_size=wm["embed_dim"], + num_hidden_layers=12, + num_attention_heads=3, + intermediate_size=wm["embed_dim"] * 4, + image_size=cfg["img_size"], + patch_size=wm["patch_size"], + num_channels=3, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + encoder = ViTModel(vit_cfg, add_pooling_layer=False) + hidden_dim = wm["embed_dim"] # ViT-tiny: 192 # ARPredictor: input_dim and hidden_dim can differ. # Here we keep them equal (both embed_dim), matching le-wm defaults. From 9822648b793aea5ca33a9b8973950428de76d253 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 2 Apr 2026 11:48:10 +0530 Subject: [PATCH 15/29] update --- examples/leWorldModel/bench.py | 4 +- examples/leWorldModel/jepa.py | 153 +++++++++++++++++ examples/leWorldModel/module.py | 285 ++++++++++++++++++++++++++++++++ examples/leWorldModel/train.py | 36 ++-- 4 files changed, 456 insertions(+), 22 deletions(-) create mode 100644 examples/leWorldModel/jepa.py create mode 100644 examples/leWorldModel/module.py diff --git a/examples/leWorldModel/bench.py b/examples/leWorldModel/bench.py index 38a4a20..eb3c159 100644 --- a/examples/leWorldModel/bench.py +++ b/examples/leWorldModel/bench.py @@ -51,6 +51,7 @@ BATCH_SIZE = 128 NUM_STEPS = 4 +FRAMESKIP = 5 # must match training config; both backends use the same value IMAGE_SIZE = 224 NUM_WORKERS = 8 PREFETCH_FACTOR = 3 @@ -91,7 +92,7 @@ class HDF5LeWMDataset(torch.utils.data.Dataset): h5py is opened lazily per worker because handles are not fork-safe. """ - def __init__(self, hdf5_src, columns, num_steps=NUM_STEPS, frameskip=1): + def __init__(self, hdf5_src, columns, num_steps=NUM_STEPS, frameskip=FRAMESKIP): self._src = hdf5_src self.columns = columns self.num_steps = num_steps @@ -283,6 +284,7 @@ def main(): columns=args.columns, batch_size=args.batch_size, num_steps=NUM_STEPS, + frameskip=FRAMESKIP, img_size=IMAGE_SIZE, num_workers=args.num_workers, prefetch_factor=PREFETCH_FACTOR, diff --git a/examples/leWorldModel/jepa.py b/examples/leWorldModel/jepa.py new file mode 100644 index 0000000..486fe93 --- /dev/null +++ b/examples/leWorldModel/jepa.py @@ -0,0 +1,153 @@ +"""JEPA Implementation""" + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +def detach_clone(v): + return v.detach().clone() if torch.is_tensor(v) else v + +class JEPA(nn.Module): + + def __init__( + self, + encoder, + predictor, + action_encoder, + projector=None, + pred_proj=None, + ): + super().__init__() + + self.encoder = encoder + self.predictor = predictor + self.action_encoder = action_encoder + self.projector = projector or nn.Identity() + self.pred_proj = pred_proj or nn.Identity() + + def encode(self, info): + """Encode observations and actions into embeddings. + info: dict with pixels and action keys + """ + + pixels = info['pixels'].float() + b = pixels.size(0) + pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding + output = self.encoder(pixels, interpolate_pos_encoding=True) + pixels_emb = output.last_hidden_state[:, 0] # cls token + emb = self.projector(pixels_emb) + info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b) + + if "action" in info: + info["act_emb"] = self.action_encoder(info["action"]) + + return info + + def predict(self, emb, act_emb): + """Predict next state embedding + emb: (B, T, D) + act_emb: (B, T, A_emb) + """ + preds = self.predictor(emb, act_emb) + preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d")) + preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0)) + return preds + + #################### + ## Inference only ## + #################### + + def rollout(self, info, action_sequence, history_size: int = 3): + """Rollout the model given an initial info dict and action sequence. + pixels: (B, S, T, C, H, W) + action_sequence: (B, S, T, action_dim) + - S is the number of action plan samples + - T is the time horizon + """ + + assert "pixels" in info, "pixels not in info_dict" + H = info["pixels"].size(2) + B, S, T = action_sequence.shape[:3] + act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2) + info["action"] = act_0 + n_steps = T - H + + # copy and encode initial info dict + _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)} + _init = self.encode(_init) + emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1) + _init = {k: detach_clone(v) for k, v in _init.items()} + + # flatten batch and sample dimensions for rollout + emb = rearrange(emb, "b s ... -> (b s) ...").clone() + act = rearrange(act_0, "b s ... -> (b s) ...") + act_future = rearrange(act_future, "b s ... -> (b s) ...") + + # rollout predictor autoregressively for n_steps + HS = history_size + for t in range(n_steps): + act_emb = self.action_encoder(act) + emb_trunc = emb[:, -HS:] # (BS, HS, D) + act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) + pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) + emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D) + + next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim) + act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim) + + # predict the last state + act_emb = self.action_encoder(act) # (BS, T, A_emb) + emb_trunc = emb[:, -HS:] # (BS, HS, D) + act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb) + pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D) + emb = torch.cat([emb, pred_emb], dim=1) + + # unflatten batch and sample dimensions + pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S) + info["predicted_emb"] = pred_rollout + + return info + + def criterion(self, info_dict: dict): + """Compute the cost between predicted embeddings and goal embeddings.""" + pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim) + goal_emb = info_dict["goal_emb"] # (B, S, T, dim) + + goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb) + + # return last-step cost per action candidate + cost = F.mse_loss( + pred_emb[..., -1:, :], + goal_emb[..., -1:, :].detach(), + reduction="none", + ).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S) + + return cost + + def get_cost(self, info_dict: dict, action_candidates: torch.Tensor): + """ Compute the cost of action candidates given an info dict with goal and initial state.""" + + assert "goal" in info_dict, "goal not in info_dict" + + device = next(self.parameters()).device + for k in list(info_dict.keys()): + if torch.is_tensor(info_dict[k]): + info_dict[k] = info_dict[k].to(device) + + goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)} + goal["pixels"] = goal["goal"] + + for k in info_dict: + if k.startswith("goal_"): + goal[k[len("goal_") :]] = goal.pop(k) + + goal.pop("action") + goal = self.encode(goal) + + info_dict["goal_emb"] = goal["emb"] + info_dict = self.rollout(info_dict, action_candidates) + + cost = self.criterion(info_dict) + + return cost diff --git a/examples/leWorldModel/module.py b/examples/leWorldModel/module.py new file mode 100644 index 0000000..16c4907 --- /dev/null +++ b/examples/leWorldModel/module.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange + +def modulate(x, shift, scale): + """AdaLN-zero modulation""" + return x * (1 + scale) + shift + +class SIGReg(torch.nn.Module): + """Sketch Isotropic Gaussian Regularizer (single-GPU!)""" + + def __init__(self, knots=17, num_proj=1024): + super().__init__() + self.num_proj = num_proj + t = torch.linspace(0, 3, knots, dtype=torch.float32) + dt = 3 / (knots - 1) + weights = torch.full((knots,), 2 * dt, dtype=torch.float32) + weights[[0, -1]] = dt + window = torch.exp(-t.square() / 2.0) + self.register_buffer("t", t) + self.register_buffer("phi", window) + self.register_buffer("weights", weights * window) + + def forward(self, proj): + """ + proj: (T, B, D) + """ + # sample random projections + A = torch.randn(proj.size(-1), self.num_proj, device=proj.device) + A = A.div_(A.norm(p=2, dim=0)) + # compute the epps-pulley statistic + x_t = (proj @ A).unsqueeze(-1) * self.t + err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square() + statistic = (err @ self.weights) * proj.size(-2) + return statistic.mean() # average over projections and time + +class FeedForward(nn.Module): + """FeedForward network used in Transformers""" + + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + """Scaled dot-product attention with causal masking""" + + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + self.heads = heads + self.scale = dim_head**-0.5 + self.dropout = dropout + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, causal=True): + """ + x : (B, T, D) + """ + x = self.norm(x) + drop = self.dropout if self.training else 0.0 + qkv = self.to_qkv(x).chunk(3, dim=-1) # q, k, v: (B, heads, T, dim_head) + q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal) + out = rearrange(out, "b h t d -> b t (h d)") + return self.to_out(out) + + +class ConditionalBlock(nn.Module): + """Transformer block with AdaLN-zero conditioning""" + + def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0): + super().__init__() + + self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + self.mlp = FeedForward(dim, mlp_dim, dropout=dropout) + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True) + ) + + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(c).chunk(6, dim=-1) + ) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class Block(nn.Module): + """Standard Transformer block""" + + def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0): + super().__init__() + + self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + self.mlp = FeedForward(dim, mlp_dim, dropout=dropout) + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class Transformer(nn.Module): + """Standard Transformer with support for AdaLN-zero blocks""" + + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + depth, + heads, + dim_head, + mlp_dim, + dropout=0.0, + block_class=Block, + ): + super().__init__() + self.norm = nn.LayerNorm(hidden_dim) + self.layers = nn.ModuleList([]) + + self.input_proj = ( + nn.Linear(input_dim, hidden_dim) + if input_dim != hidden_dim + else nn.Identity() + ) + + self.cond_proj = ( + nn.Linear(input_dim, hidden_dim) + if input_dim != hidden_dim + else nn.Identity() + ) + + self.output_proj = ( + nn.Linear(hidden_dim, output_dim) + if hidden_dim != output_dim + else nn.Identity() + ) + + for _ in range(depth): + self.layers.append( + block_class(hidden_dim, heads, dim_head, mlp_dim, dropout) + ) + + def forward(self, x, c=None): + + if hasattr(self, "input_proj"): + x = self.input_proj(x) + + if c is not None and hasattr(self, "cond_proj"): + c = self.cond_proj(c) + + for block in self.layers: + x = block(x) if isinstance(block, Block) else block(x, c) + x = self.norm(x) + + if hasattr(self, "output_proj"): + x = self.output_proj(x) + return x + +class Embedder(nn.Module): + def __init__( + self, + input_dim=10, + smoothed_dim=10, + emb_dim=10, + mlp_scale=4, + ): + super().__init__() + self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1) + self.embed = nn.Sequential( + nn.Linear(smoothed_dim, mlp_scale * emb_dim), + nn.SiLU(), + nn.Linear(mlp_scale * emb_dim, emb_dim), + ) + + def forward(self, x): + """ + x: (B, T, D) + """ + x = x.float() + x = x.permute(0, 2, 1) + x = self.patch_embed(x) + x = x.permute(0, 2, 1) + x = self.embed(x) + return x + + +class MLP(nn.Module): + """Simple MLP with optional normalization and activation""" + + def __init__( + self, + input_dim, + hidden_dim, + output_dim=None, + norm_fn=nn.LayerNorm, + act_fn=nn.GELU, + ): + super().__init__() + norm_fn = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + norm_fn, + act_fn(), + nn.Linear(hidden_dim, output_dim or input_dim), + ) + + def forward(self, x): + """ + x: (B*T, D) + """ + return self.net(x) + + +class ARPredictor(nn.Module): + """Autoregressive predictor for next-step embedding prediction.""" + + def __init__( + self, + *, + num_frames, + depth, + heads, + mlp_dim, + input_dim, + hidden_dim, + output_dim=None, + dim_head=64, + dropout=0.0, + emb_dropout=0.0, + ): + super().__init__() + self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim)) + self.dropout = nn.Dropout(emb_dropout) + self.transformer = Transformer( + input_dim, + hidden_dim, + output_dim or input_dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + block_class=ConditionalBlock, + ) + + def forward(self, x, c): + """ + x: (B, T, d) + c: (B, T, act_dim) + """ + T = x.size(1) + x = x + self.pos_embedding[:, :T] + x = self.dropout(x) + x = self.transformer(x, c) + return x diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 1fa236d..43700c9 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -5,6 +5,9 @@ a LanceDB-backed DataLoader while keeping the model, loss, and Lightning training loop identical to the original. +jepa.py and module.py are vendored directly from https://github.com/lucas-maes/le-wm +(no git clone required). + Usage: # Local LanceDB store (defaults come from config) python train.py --config config/lewm_pusht.yaml @@ -29,17 +32,6 @@ _HERE = os.path.dirname(__file__) sys.path.insert(0, _HERE) -# jepa.py and module.py live at the root of the le-wm repo (not a Python package). -# Clone it next to this file: git clone https://github.com/lucas-maes/le-wm -_LEWM_DIR = os.path.join(_HERE, "le-wm") -if not os.path.isdir(_LEWM_DIR): - raise RuntimeError( - f"le-wm repo not found at {_LEWM_DIR}.\n" - "Run: git clone https://github.com/lucas-maes/le-wm " - f"{_LEWM_DIR}" - ) -sys.path.insert(0, _LEWM_DIR) - import torch import torch.nn as nn from transformers import ViTConfig, ViTModel @@ -184,27 +176,29 @@ def build_model(cfg: dict, effective_act_dim: int) -> tuple[JEPA, SIGReg]: wm = cfg["wm"] pred = cfg["predictor"] - # Build ViT-tiny with HuggingFace transformers — same as le-wm's vit_hf("tiny", ...) + # ViT-tiny — identical to spt.backbone.utils.vit_hf("tiny", patch_size=14, image_size=224, + # pretrained=False, use_mask_token=False) from stable_pretraining. + # vit_hf builds ViTModel(ViTConfig(**size_configs["tiny"]), add_pooling_layer=False, + # use_mask_token=False) where size_configs["tiny"] = {hidden_size:192, num_hidden_layers:12, + # num_attention_heads:3, intermediate_size:768}. vit_cfg = ViTConfig( - hidden_size=wm["embed_dim"], + hidden_size=wm["embed_dim"], # 192 num_hidden_layers=12, num_attention_heads=3, - intermediate_size=wm["embed_dim"] * 4, - image_size=cfg["img_size"], - patch_size=wm["patch_size"], - num_channels=3, + intermediate_size=wm["embed_dim"] * 4, # 768 + image_size=cfg["img_size"], # 224 + patch_size=wm["patch_size"], # 14 hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, ) - encoder = ViTModel(vit_cfg, add_pooling_layer=False) - hidden_dim = wm["embed_dim"] # ViT-tiny: 192 + encoder = ViTModel(vit_cfg, add_pooling_layer=False, use_mask_token=False) + hidden_dim = wm["embed_dim"] # ViT-tiny hidden_size: 192 - # ARPredictor: input_dim and hidden_dim can differ. - # Here we keep them equal (both embed_dim), matching le-wm defaults. predictor = ARPredictor( num_frames=wm["history_size"], input_dim=wm["embed_dim"], hidden_dim=hidden_dim, + output_dim=hidden_dim, depth=pred["depth"], heads=pred["heads"], mlp_dim=pred["mlp_dim"], From 38399aebeede5edd3e4989c7b032db5498ef49db Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 2 Apr 2026 12:40:04 +0530 Subject: [PATCH 16/29] improve loading --- examples/leWorldModel/config/lewm_pusht.yaml | 6 +-- examples/leWorldModel/lewm_loader/dataset.py | 49 ++++++++++++++++---- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml index 80bba91..9ab2ba8 100644 --- a/examples/leWorldModel/config/lewm_pusht.yaml +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -7,7 +7,7 @@ seed: 42 img_size: 224 trainer: - max_epochs: 100 + max_epochs: 10 precision: "bf16-mixed" # PyTorch Lightning format (le-wm uses "bf16" via Hydra) gradient_clip_val: 1.0 log_every_n_steps: 50 @@ -16,8 +16,8 @@ trainer: loader: batch_size: 128 - num_workers: 6 - prefetch_factor: 3 + num_workers: 12 + prefetch_factor: 4 optimizer: lr: 5.0e-5 diff --git a/examples/leWorldModel/lewm_loader/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py index c380598..f0c3da1 100644 --- a/examples/leWorldModel/lewm_loader/dataset.py +++ b/examples/leWorldModel/lewm_loader/dataset.py @@ -186,19 +186,14 @@ def _rows_to_sample(self, batch: pa.RecordBatch) -> dict[str, torch.Tensor]: continue if col == "action": - # All span steps → (span, action_dim) → (T, frameskip × action_dim) - data = np.array( - [batch[col][i].as_py() for i in range(self._span)], - dtype=np.float32, - ) + # to_pylist() batches the Arrow → Python conversion in one call + # instead of .as_py() in a Python loop + data = np.array(batch.column(col).to_pylist(), dtype=np.float32) data = np.nan_to_num(data, nan=0.0) data = data.reshape(T, -1) # (T, frameskip × action_dim) else: - # Stride by frameskip → T steps - data = np.array( - [batch[col][t * frameskip].as_py() for t in range(T)], - dtype=np.float32, - ) + data = np.array(batch.column(col).to_pylist(), dtype=np.float32) + data = data[::frameskip] # stride without a Python loop if col in self.normalizers: mean, std = self.normalizers[col] data = (data - mean) / (std + 1e-8) @@ -214,3 +209,37 @@ def __getitem__(self, window_idx: int) -> dict[str, torch.Tensor]: rows = list(range(start, start + self._span)) batch = self._perm.__getitems__(rows) return self._rows_to_sample(batch) + + def __getitems__(self, indices: list[int]) -> list[dict[str, torch.Tensor]]: + """ + Fetch an entire DataLoader batch in one round trip. + + Permutation.__getitems__ deduplicates row indices, so we cannot pass + all B*span rows directly (overlapping windows would silently drop rows). + Instead we: + 1. Collect the exact row ranges for each window. + 2. Deduplicate ourselves → sorted unique row list. + 3. Single Permutation fetch for those unique rows. + 4. Reconstruct each window by indexing into the fetched result. + This reduces S3 round trips from B (one per sample) to 1 per batch. + """ + self._ensure_open() + + # Step 1 — row ranges per window + window_rows = [ + list(range(int(self._window_starts[i]), int(self._window_starts[i]) + self._span)) + for i in indices + ] + + # Step 2 — unique sorted rows + reverse mapping + all_rows = sorted(set(r for rows in window_rows for r in rows)) + row_to_pos = {r: pos for pos, r in enumerate(all_rows)} + + # Step 3 — single fetch + fetched = self._perm.__getitems__(all_rows) + + # Step 4 — reconstruct each window + return [ + self._rows_to_sample(fetched.take([row_to_pos[r] for r in rows])) + for rows in window_rows + ] From 23365409431ce0fb1341eb584a357637e3578abb Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 2 Apr 2026 14:22:20 +0530 Subject: [PATCH 17/29] update --- examples/leWorldModel/bench.py | 10 ++++++++++ examples/leWorldModel/dataset.md | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/examples/leWorldModel/bench.py b/examples/leWorldModel/bench.py index eb3c159..731d79d 100644 --- a/examples/leWorldModel/bench.py +++ b/examples/leWorldModel/bench.py @@ -151,6 +151,16 @@ def __getitem__(self, clip_idx): sample[col] = torch.from_numpy(np.nan_to_num(data, nan=0.0)) return sample + def __getitems__(self, indices: list[int]) -> list[dict]: + """Batch pixel reads sorted by file offset to improve sequential access.""" + self._ensure_open() + # Sort by file offset so HDF5 reads are as sequential as possible + order = sorted(range(len(indices)), key=lambda i: self._clip_indices[indices[i]]) + results = [None] * len(indices) + for pos in order: + results[pos] = self.__getitem__(indices[pos]) + return results + def _collate(samples): return {k: torch.stack([s[k] for s in samples], dim=0) for k in samples[0]} diff --git a/examples/leWorldModel/dataset.md b/examples/leWorldModel/dataset.md index ed0d737..c30c716 100644 --- a/examples/leWorldModel/dataset.md +++ b/examples/leWorldModel/dataset.md @@ -187,3 +187,29 @@ the dataset object — it is pickled safely to DataLoader workers. The LanceDB pipeline is a direct structural equivalent of the HDF5 pipeline with a different I/O backend. + + +``` + +# gpu sm mem enc dec jpg ofa +# Idx % % % % % % + 0 100 100 0 0 0 0 + 0 99 91 0 0 0 0 + 0 100 100 0 0 0 0 + 0 100 100 0 0 0 0 + 0 98 91 0 0 0 0 + 0 100 100 0 0 0 0 + 0 99 93 0 0 0 0 + 0 99 91 0 0 0 0 + 0 98 89 0 0 0 0 + 0 100 100 0 0 0 0 + 0 100 94 0 0 0 0 + 0 98 91 0 0 0 0 + 0 100 100 0 0 0 0 + 0 100 100 0 0 0 0 + 0 98 92 0 0 0 0 + 0 100 100 0 0 0 0 + 0 100 100 0 0 0 0 + 0 98 93 0 0 0 0 + +``` \ No newline at end of file From 145317262c364dbb24ffd81aa1e3f58972a77b80 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 2 Apr 2026 14:47:55 +0530 Subject: [PATCH 18/29] updare --- examples/leWorldModel/config/lewm_pusht.yaml | 3 +-- examples/leWorldModel/train.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml index 9ab2ba8..fcf0184 100644 --- a/examples/leWorldModel/config/lewm_pusht.yaml +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -7,7 +7,7 @@ seed: 42 img_size: 224 trainer: - max_epochs: 10 + max_epochs: 100 precision: "bf16-mixed" # PyTorch Lightning format (le-wm uses "bf16" via Hydra) gradient_clip_val: 1.0 log_every_n_steps: 50 @@ -22,7 +22,6 @@ loader: optimizer: lr: 5.0e-5 weight_decay: 1.0e-3 - warmup_epochs: 10 # added: controls LinearWarmup → Cosine schedule wm: history_size: 3 diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 43700c9..6e412d2 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -146,19 +146,20 @@ def configure_optimizers(self): lr=self.cfg["lr"], weight_decay=self.cfg["weight_decay"], ) - warmup = self.cfg["warmup_epochs"] - total = self.cfg["max_epochs"] - # Linear warmup → cosine decay (matches le-wm's LinearWarmupCosineAnnealingLR) + # Replicate le-wm's LinearWarmupCosineAnnealingLR exactly: + # warmup_steps = 1% of total steps (step-based, not epoch-based) + total_steps = self.trainer.estimated_stepping_batches + warmup_steps = max(1, int(0.01 * total_steps)) warmup_sched = torch.optim.lr_scheduler.LinearLR( - opt, start_factor=1e-4, end_factor=1.0, total_iters=warmup + opt, start_factor=0.0 + 1e-8, end_factor=1.0, total_iters=warmup_steps ) cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR( - opt, T_max=max(total - warmup, 1), eta_min=0 + opt, T_max=max(total_steps - warmup_steps, 1), eta_min=0 ) sched = torch.optim.lr_scheduler.SequentialLR( - opt, schedulers=[warmup_sched, cosine_sched], milestones=[warmup] + opt, schedulers=[warmup_sched, cosine_sched], milestones=[warmup_steps] ) - return [opt], [{"scheduler": sched, "interval": "epoch"}] + return [opt], [{"scheduler": sched, "interval": "step"}] # --------------------------------------------------------------------------- @@ -348,7 +349,6 @@ def main(): lightning_cfg = { "lr": opt_cfg["lr"], "weight_decay": opt_cfg["weight_decay"], - "warmup_epochs": opt_cfg["warmup_epochs"], "max_epochs": trainer_cfg["max_epochs"], "sigreg_weight": cfg["loss"]["sigreg"]["weight"], "history_size": wm_cfg["history_size"], From 790c67e78096e04e0c276d88015b8a3df4d1b843 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 12:00:41 +0530 Subject: [PATCH 19/29] minor fix --- examples/leWorldModel/create_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index 2c072b1..43a96d0 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -480,8 +480,8 @@ def setup(self): ).cuda().eval() self.torch = torch - def __call__(self, pixel_bytes: bytes) -> list[float]: - img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + def __call__(self, pixels: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixels)).convert("RGB") t = _transform(img).unsqueeze(0).cuda() with self.torch.no_grad(): return self.model(t)[0].cpu().tolist() @@ -499,8 +499,8 @@ def setup(self): self.model.eval() self.torch = torch - def __call__(self, pixel_bytes: bytes) -> list[float]: - img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + def __call__(self, pixels: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixels)).convert("RGB") t = self.preprocess(img).unsqueeze(0).cuda() with self.torch.no_grad(): return self.model.encode_image(t)[0].cpu().float().tolist() @@ -520,8 +520,8 @@ def setup(self): self.encoder = model.encoder self.torch = torch - def __call__(self, pixel_bytes: bytes) -> list[float]: - img = _Image.open(_io.BytesIO(pixel_bytes)).convert("RGB") + def __call__(self, pixels: bytes) -> list[float]: + img = _Image.open(_io.BytesIO(pixels)).convert("RGB") t = _transform(img).unsqueeze(0).cuda() with self.torch.no_grad(): out = self.encoder(t) From 021dbe144413afb1deb5d539ae4b21dab31dc816 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 12:16:57 +0530 Subject: [PATCH 20/29] add eval scripts --- examples/leWorldModel/README.md | 32 +++- examples/leWorldModel/config/eval/cube.yaml | 61 +++++++ .../config/eval/launcher/local.yaml | 7 + examples/leWorldModel/config/eval/pusht.yaml | 48 +++++ .../leWorldModel/config/eval/reacher.yaml | 50 +++++ .../leWorldModel/config/eval/solver/adam.yaml | 13 ++ .../leWorldModel/config/eval/solver/cem.yaml | 9 + .../leWorldModel/config/eval/tworoom.yaml | 47 +++++ examples/leWorldModel/create_data.py | 27 ++- examples/leWorldModel/eval.py | 171 ++++++++++++++++++ examples/leWorldModel/utils.py | 57 ++++++ 11 files changed, 513 insertions(+), 9 deletions(-) create mode 100644 examples/leWorldModel/config/eval/cube.yaml create mode 100644 examples/leWorldModel/config/eval/launcher/local.yaml create mode 100644 examples/leWorldModel/config/eval/pusht.yaml create mode 100644 examples/leWorldModel/config/eval/reacher.yaml create mode 100644 examples/leWorldModel/config/eval/solver/adam.yaml create mode 100644 examples/leWorldModel/config/eval/solver/cem.yaml create mode 100644 examples/leWorldModel/config/eval/tworoom.yaml create mode 100644 examples/leWorldModel/eval.py create mode 100644 examples/leWorldModel/utils.py diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index e90667b..d94c34c 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -140,7 +140,37 @@ python train.py --config config/lewm_pusht.yaml \ python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm ``` -### Step 5 — Post-training analysis with LeWM embeddings +### Step 5 — Evaluate planning performance (reproducing paper Table 1 / Figure 6) + +The paper's headline metric is **planning success rate** using CEM (Cross-Entropy Method) over the learned latent space — not loss values. To reproduce: + +```bash +# Install the evaluation stack (requires stable-worldmodel with env extras) +pip install "stable-worldmodel[train,env]" + +# Copy the checkpoint to stable-worldmodel's cache directory +mkdir -p ~/.stable_worldmodel/pusht +cp checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt ~/.stable_worldmodel/pusht/lewm_object.ckpt + +# Run evaluation on PushT (target: ~90% success rate, Figure 6) +# eval.py and config/eval/ are vendored from https://github.com/lucas-maes/le-wm +python eval.py --config-name=pusht.yaml policy=pusht/lewm +``` + +Expected results from the paper (Figure 6): + +| Dataset | LeWM success rate | +|----------|-------------------| +| PushT | ~90% | +| TwoRoom | ~97% | +| OGBench-Cube | ~74% | +| Reacher | ~86% | + +> Note: the paper trains for 10 epochs on PushT and observes that further training does not improve planning performance. Evaluate the epoch-10 checkpoint first. + +--- + +### Step 6 — Post-training analysis with LeWM embeddings #### What we're doing and why diff --git a/examples/leWorldModel/config/eval/cube.yaml b/examples/leWorldModel/config/eval/cube.yaml new file mode 100644 index 0000000..3ba34bf --- /dev/null +++ b/examples/leWorldModel/config/eval/cube.yaml @@ -0,0 +1,61 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/OGBCube-v0 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + env_type: single + ob_type: states + multiview: False + width: 224 + height: 224 + visualize_info: False + terminate_at_goal: True + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: ogbench/cube_single_expert + callables: + # -- set state + - method: set_state + args: + qpos: + value: qpos + qvel: + value: qvel + # -- set target pos + - method: set_target_pos + args: + cube_id: + value: 0 + in_dataset: False + target_pos: + value: goal_privileged_block_0_pos + target_quat: + value: goal_privileged_block_0_quat + +output: + filename: ogb_cube_results.txt + diff --git a/examples/leWorldModel/config/eval/launcher/local.yaml b/examples/leWorldModel/config/eval/launcher/local.yaml new file mode 100644 index 0000000..a1c7f4b --- /dev/null +++ b/examples/leWorldModel/config/eval/launcher/local.yaml @@ -0,0 +1,7 @@ +# @package _global_ +# Local launcher configuration (no SLURM) + +defaults: + - override /hydra/launcher: basic + +cache_dir: null # use stable-worldmodel default cache diff --git a/examples/leWorldModel/config/eval/pusht.yaml b/examples/leWorldModel/config/eval/pusht.yaml new file mode 100644 index 0000000..6584ef0 --- /dev/null +++ b/examples/leWorldModel/config/eval/pusht.yaml @@ -0,0 +1,48 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/PushT-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + - state + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: pusht_expert_train + callables: + # -- set state + - method: _set_state + args: + state: + value: state + # -- set goal state + - method: _set_goal_state + args: + goal_state: + value: goal_state + +output: + filename: pusht_results.txt \ No newline at end of file diff --git a/examples/leWorldModel/config/eval/reacher.yaml b/examples/leWorldModel/config/eval/reacher.yaml new file mode 100644 index 0000000..d0c62dc --- /dev/null +++ b/examples/leWorldModel/config/eval/reacher.yaml @@ -0,0 +1,50 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/ReacherDMControl-v0 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + task: qpos_match + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + +seed: 42 +policy: random # ckpt name or random + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: dmc/reacher_random + callables: + # -- set state + - method: set_state + args: + qpos: + value: qpos + qvel: + value: qvel + + - method: set_target_qpos + args: + target_qpos: + value: goal_qpos + +output: + filename: dmc_results.txt + diff --git a/examples/leWorldModel/config/eval/solver/adam.yaml b/examples/leWorldModel/config/eval/solver/adam.yaml new file mode 100644 index 0000000..763b0f0 --- /dev/null +++ b/examples/leWorldModel/config/eval/solver/adam.yaml @@ -0,0 +1,13 @@ +_target_: stable_worldmodel.solver.GradientSolver +model: ??? +n_steps: 30 +batch_size: 1 +num_samples: 100 +action_noise: 0 +device: "cuda" +seed: ${seed} +optimizer_cls: + _target_: hydra.utils.get_class + path: torch.optim.AdamW +optimizer_kwargs: + lr: 0.1 \ No newline at end of file diff --git a/examples/leWorldModel/config/eval/solver/cem.yaml b/examples/leWorldModel/config/eval/solver/cem.yaml new file mode 100644 index 0000000..8d24fda --- /dev/null +++ b/examples/leWorldModel/config/eval/solver/cem.yaml @@ -0,0 +1,9 @@ +_target_: stable_worldmodel.solver.CEMSolver +model: ??? +batch_size: 1 +num_samples: 300 +var_scale: 1.0 +n_steps: 30 +topk: 30 +device: "cuda" +seed: ${seed} diff --git a/examples/leWorldModel/config/eval/tworoom.yaml b/examples/leWorldModel/config/eval/tworoom.yaml new file mode 100644 index 0000000..dd20571 --- /dev/null +++ b/examples/leWorldModel/config/eval/tworoom.yaml @@ -0,0 +1,47 @@ +defaults: + - launcher: local + - solver: cem + - _self_ + +world: + env_name: swm/TwoRoom-v1 + num_envs: ${eval.num_eval} + max_episode_steps: ??? # make sure it's >= eval_budget + history_size: 1 + frame_skip: 1 + +seed: 42 +policy: random # ckpt name or random + +dataset: + stats: ${eval.dataset_name} + keys_to_cache: + - action + - proprio + +plan_config: + horizon: 5 + receding_horizon: 5 + action_block: 5 # frameskip + +# evaluation from dataset (replay expert trajectories) +eval: + num_eval: 50 + goal_offset_steps: 25 + eval_budget: 50 + img_size: 224 + dataset_name: tworoom + callables: + # -- set state + - method: _set_state + args: + state: + value: proprio + # -- set goal state + - method: _set_goal_state + args: + goal_state: + value: goal_proprio + +output: + filename: tworoom_results.txt \ No newline at end of file diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index 43a96d0..e045e97 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -390,7 +390,7 @@ def add_embeddings_geneva( checkpoint: str | None = None, batch_size: int = 32, img_size: int = 224, - concurrency: int = 1, + concurrency: int = 2, connect_kwargs: dict | None = None, ): """ @@ -417,8 +417,8 @@ def add_embeddings_geneva( col_name = f"emb_{model_name}" if col_name in tbl.schema.names: - print(f" '{col_name}' column already present. Skipping.") - return + print(f" '{col_name}' column already present. dropping.") + tbl.drop_columns([col_name]) # Build the UDF class for the chosen model udf_cls = _make_embedding_udf(model_name, checkpoint, img_size) @@ -471,7 +471,7 @@ def _make_embedding_udf(model_name: str, checkpoint: str | None, img_size: int): if model_name == "dinov2": EMBED_DIM = 384 - @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM)) + @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM), num_gpus=0.5) class DINOv2Embedder: def setup(self): import timm, torch @@ -491,7 +491,7 @@ def __call__(self, pixels: bytes) -> list[float]: if model_name == "clip": EMBED_DIM = 512 - @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM)) + @geneva.udf(data_type=pa.list_(pa.float32(), EMBED_DIM), num_gpus=0.5) class CLIPEmbedder: def setup(self): import clip, torch @@ -511,16 +511,27 @@ def __call__(self, pixels: bytes) -> list[float]: assert checkpoint, "--checkpoint is required for --embedding-model lewm" _ckpt = checkpoint - @geneva.udf(data_type=pa.list_(pa.float32(), 192)) # ViT-tiny embed_dim + @geneva.udf(data_type=pa.list_(pa.float32(), 192), num_gpus=0.5) # ViT-tiny embed_dim class LeWMEmbedder: def setup(self): - import torch - model = torch.load(_ckpt, map_location="cuda") + import sys, os, torch + # Ray workers need jepa.py on the path (vendored next to create_data.py) + _here = os.path.dirname(os.path.abspath(__file__)) + if _here not in sys.path: + sys.path.insert(0, _here) + # weights_only=False required for torch.save'd model objects (PyTorch >= 2.6) + model = torch.load(_ckpt, map_location="cuda", weights_only=False) model.eval() self.encoder = model.encoder self.torch = torch + + def _ensure_setup(self): + if not hasattr(self, "torch"): + self.setup() + def __call__(self, pixels: bytes) -> list[float]: + self._ensure_setup() img = _Image.open(_io.BytesIO(pixels)).convert("RGB") t = _transform(img).unsqueeze(0).cuda() with self.torch.no_grad(): diff --git a/examples/leWorldModel/eval.py b/examples/leWorldModel/eval.py new file mode 100644 index 0000000..859afd1 --- /dev/null +++ b/examples/leWorldModel/eval.py @@ -0,0 +1,171 @@ +import os + +os.environ["MUJOCO_GL"] = "egl" + +import time +from pathlib import Path + +import hydra +import numpy as np +import stable_pretraining as spt +import torch +from omegaconf import DictConfig, OmegaConf +from sklearn import preprocessing +from torchvision.transforms import v2 as transforms +import stable_worldmodel as swm + +def img_transform(cfg): + transform = transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + transforms.Normalize(**spt.data.dataset_stats.ImageNet), + transforms.Resize(size=cfg.eval.img_size), + ] + ) + return transform + + +def get_episodes_length(dataset, episodes): + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + + episode_idx = dataset.get_col_data(col_name) + step_idx = dataset.get_col_data("step_idx") + lengths = [] + for ep_id in episodes: + lengths.append(np.max(step_idx[episode_idx == ep_id]) + 1) + return np.array(lengths) + + +def get_dataset(cfg, dataset_name): + dataset_path = Path(cfg.cache_dir or swm.data.utils.get_cache_dir()) + dataset = swm.data.HDF5Dataset( + dataset_name, + keys_to_cache=cfg.dataset.keys_to_cache, + cache_dir=dataset_path, + ) + return dataset + +@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht") +def run(cfg: DictConfig): + """Run evaluation of dinowm vs random policy.""" + assert ( + cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget + ), "Planning horizon must be smaller than or equal to eval_budget" + + # create world environment + cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget + world = swm.World(**cfg.world, image_shape=(224, 224)) + + # create the transform + transform = { + "pixels": img_transform(cfg), + "goal": img_transform(cfg), + } + + dataset = get_dataset(cfg, cfg.eval.dataset_name) + stats_dataset = dataset # get_dataset(cfg, cfg.dataset.stats) + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + ep_indices, _ = np.unique(stats_dataset.get_col_data(col_name), return_index=True) + + process = {} + for col in cfg.dataset.keys_to_cache: + if col in ["pixels"]: + continue + processor = preprocessing.StandardScaler() + col_data = stats_dataset.get_col_data(col) + col_data = col_data[~np.isnan(col_data).any(axis=1)] + processor.fit(col_data) + process[col] = processor + + if col != "action": + process[f"goal_{col}"] = process[col] + + # -- run evaluation + policy = cfg.get("policy", "random") + + if policy != "random": + model = swm.policy.AutoCostModel(cfg.policy) + model = model.to("cuda") + model = model.eval() + model.requires_grad_(False) + model.interpolate_pos_encoding = True + config = swm.PlanConfig(**cfg.plan_config) + solver = hydra.utils.instantiate(cfg.solver, model=model) + policy = swm.policy.WorldModelPolicy( + solver=solver, config=config, process=process, transform=transform + ) + + else: + policy = swm.policy.RandomPolicy() + + results_path = ( + Path(swm.data.utils.get_cache_dir(), cfg.policy).parent + if cfg.policy != "random" + else Path(__file__).parent + ) + + # sample the episodes and the starting indices + episode_len = get_episodes_length(dataset, ep_indices) + max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1 + max_start_idx_dict = {ep_id: max_start_idx[i] for i, ep_id in enumerate(ep_indices)} + # Map each dataset row’s episode_idx to its max_start_idx + col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx" + max_start_per_row = np.array( + [max_start_idx_dict[ep_id] for ep_id in dataset.get_col_data(col_name)] + ) + + # remove all the lines of dataset for which dataset['step_idx'] > max_start_per_row + valid_mask = dataset.get_col_data("step_idx") <= max_start_per_row + valid_indices = np.nonzero(valid_mask)[0] + print(valid_mask.sum(), "valid starting points found for evaluation.") + + g = np.random.default_rng(cfg.seed) + random_episode_indices = g.choice( + len(valid_indices) - 1, size=cfg.eval.num_eval, replace=False + ) + + # sort increasingly to avoid issues with HDF5Dataset indexing + random_episode_indices = np.sort(valid_indices[random_episode_indices]) + + print(random_episode_indices) + + eval_episodes = dataset.get_row_data(random_episode_indices)[col_name] + eval_start_idx = dataset.get_row_data(random_episode_indices)["step_idx"] + + if len(eval_episodes) < cfg.eval.num_eval: + raise ValueError("Not enough episodes with sufficient length for evaluation.") + + world.set_policy(policy) + + start_time = time.time() + metrics = world.evaluate_from_dataset( + dataset, + start_steps=eval_start_idx.tolist(), + goal_offset_steps=cfg.eval.goal_offset_steps, + eval_budget=cfg.eval.eval_budget, + episodes_idx=eval_episodes.tolist(), + callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True), + video_path=results_path, + ) + end_time = time.time() + + print(metrics) + + results_path = results_path / cfg.output.filename + results_path.parent.mkdir(parents=True, exist_ok=True) + + with results_path.open("a") as f: + f.write("\n") # separate from previous runs + + f.write("==== CONFIG ====\n") + f.write(OmegaConf.to_yaml(cfg)) + f.write("\n") + + f.write("==== RESULTS ====\n") + f.write(f"metrics: {metrics}\n") + f.write(f"evaluation_time: {end_time - start_time} seconds\n") + + +if __name__ == "__main__": + run() diff --git a/examples/leWorldModel/utils.py b/examples/leWorldModel/utils.py new file mode 100644 index 0000000..a1c234e --- /dev/null +++ b/examples/leWorldModel/utils.py @@ -0,0 +1,57 @@ +import numpy as np +import torch +from pathlib import Path +from stable_pretraining import data as dt +from lightning.pytorch.callbacks import Callback + +def get_img_preprocessor(source: str, target: str, img_size: int = 224): + imagenet_stats = dt.dataset_stats.ImageNet + to_image = dt.transforms.ToImage(**imagenet_stats, source=source, target=target) + resize = dt.transforms.Resize(img_size, source=source, target=target) + return dt.transforms.Compose(to_image, resize) + + +def get_column_normalizer(dataset, source: str, target: str): + """Get normalizer for a specific column in the dataset.""" + col_data = dataset.get_col_data(source) + data = torch.from_numpy(np.array(col_data)) + data = data[~torch.isnan(data).any(dim=1)] + mean = data.mean(0, keepdim=True).clone() + std = data.std(0, keepdim=True).clone() + + def norm_fn(x): + return ((x - mean) / std).float() + + normalizer = dt.transforms.WrapTorchTransform(norm_fn, source=source, target=target) + return normalizer + +class ModelObjectCallBack(Callback): + """Callback to pickle model object after each epoch.""" + + def __init__(self, dirpath, filename="model_object", epoch_interval: int = 1): + super().__init__() + self.dirpath = Path(dirpath) + self.filename = filename + self.epoch_interval = epoch_interval + + def on_train_epoch_end(self, trainer, pl_module): + super().on_train_epoch_end(trainer, pl_module) + + output_path = ( + self.dirpath + / f"{self.filename}_epoch_{trainer.current_epoch + 1}_object.ckpt" + ) + + if trainer.is_global_zero: + if (trainer.current_epoch + 1) % self.epoch_interval == 0: + self._dump_model(pl_module.model, output_path) + + # save final epoch + if (trainer.current_epoch + 1) == trainer.max_epochs: + self._dump_model(pl_module.model, output_path) + + def _dump_model(self, model, path): + try: + torch.save(model, path) + except Exception as e: + print(f"Error saving model object: {e}") \ No newline at end of file From b7589ac4c2906bef86d211b9c5f825964b66ec57 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 12:21:38 +0530 Subject: [PATCH 21/29] update --- examples/leWorldModel/create_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/leWorldModel/create_data.py b/examples/leWorldModel/create_data.py index e045e97..b02bb77 100644 --- a/examples/leWorldModel/create_data.py +++ b/examples/leWorldModel/create_data.py @@ -535,8 +535,8 @@ def __call__(self, pixels: bytes) -> list[float]: img = _Image.open(_io.BytesIO(pixels)).convert("RGB") t = _transform(img).unsqueeze(0).cuda() with self.torch.no_grad(): - out = self.encoder(t) - return out[0, 0, :].cpu().tolist() # CLS token + out = self.encoder(t, interpolate_pos_encoding=True) + return out.last_hidden_state[0, 0, :].cpu().tolist() # CLS token return LeWMEmbedder From 77f810e29cb3f61d8b0b52726d20020d1542c13c Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 12:48:11 +0530 Subject: [PATCH 22/29] add prep eval util --- examples/leWorldModel/README.md | 17 ++++--- examples/leWorldModel/prepare_eval.py | 64 +++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) create mode 100644 examples/leWorldModel/prepare_eval.py diff --git a/examples/leWorldModel/README.md b/examples/leWorldModel/README.md index d94c34c..240ece8 100644 --- a/examples/leWorldModel/README.md +++ b/examples/leWorldModel/README.md @@ -145,18 +145,23 @@ python train.py --config config/lewm_pusht.yaml --lance-uri s3://my-bucket/lewm The paper's headline metric is **planning success rate** using CEM (Cross-Entropy Method) over the learned latent space — not loss values. To reproduce: ```bash -# Install the evaluation stack (requires stable-worldmodel with env extras) +# Install the evaluation stack pip install "stable-worldmodel[train,env]" -# Copy the checkpoint to stable-worldmodel's cache directory -mkdir -p ~/.stable_worldmodel/pusht -cp checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt ~/.stable_worldmodel/pusht/lewm_object.ckpt +# stable_worldmodel's AutoCostModel looks for checkpoints under $STABLEWM_HOME +# as _object.ckpt. prepare_eval.py handles the placement automatically: +python prepare_eval.py \ + --checkpoint checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt \ + --dataset pusht + +# prepare_eval.py prints the exact command to run, e.g.: +python eval.py --config-name=pusht.yaml policy=lewm_pusht_lewm_epoch_10 -# Run evaluation on PushT (target: ~90% success rate, Figure 6) # eval.py and config/eval/ are vendored from https://github.com/lucas-maes/le-wm -python eval.py --config-name=pusht.yaml policy=pusht/lewm ``` +`prepare_eval.py` symlinks the checkpoint into `~/.stable_worldmodel/` with the name that `AutoCostModel` expects, then prints the ready-to-run `eval.py` command. Use `--copy` if the checkpoint and home directory are on different filesystems. + Expected results from the paper (Figure 6): | Dataset | LeWM success rate | diff --git a/examples/leWorldModel/prepare_eval.py b/examples/leWorldModel/prepare_eval.py new file mode 100644 index 0000000..8d1c38e --- /dev/null +++ b/examples/leWorldModel/prepare_eval.py @@ -0,0 +1,64 @@ +""" +Prepare a LeWM checkpoint for evaluation with eval.py. + +stable_worldmodel's AutoCostModel expects: + - checkpoint at $STABLEWM_HOME/_object.ckpt + - policy argument passed as (without _object.ckpt suffix) + +This script copies/symlinks the checkpoint to the right location and prints +the exact eval.py command to run. + +Usage: + python prepare_eval.py --checkpoint checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt + python prepare_eval.py --checkpoint checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt --run-name lewm_pusht +""" +import argparse +import os +import shutil +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", required=True, help="Path to *_object.ckpt file") + parser.add_argument("--run-name", default=None, + help="Name to use under STABLEWM_HOME (default: derived from checkpoint filename)") + parser.add_argument("--dataset", default="pusht", choices=["pusht", "cube", "reacher", "tworoom"], + help="Dataset to evaluate on") + parser.add_argument("--copy", action="store_true", + help="Copy instead of symlinking (use when src and dst are on different filesystems)") + args = parser.parse_args() + + ckpt = Path(args.checkpoint).resolve() + if not ckpt.exists(): + raise FileNotFoundError(f"Checkpoint not found: {ckpt}") + + stablewm_home = Path(os.environ.get("STABLEWM_HOME", Path.home() / ".stable_worldmodel")) + stablewm_home.mkdir(parents=True, exist_ok=True) + + # Derive run_name: strip _object.ckpt suffix if present + stem = ckpt.stem # e.g. lewm_pusht_lewm_epoch_10_object + if stem.endswith("_object"): + stem = stem[: -len("_object")] # lewm_pusht_lewm_epoch_10 + run_name = args.run_name or stem + + dst = stablewm_home / f"{run_name}_object.ckpt" + + if dst.exists() or dst.is_symlink(): + dst.unlink() + + if args.copy: + shutil.copy2(ckpt, dst) + print(f"Copied {ckpt}") + else: + dst.symlink_to(ckpt) + print(f"Symlinked {ckpt}") + + print(f" → {dst}") + print() + print("Run evaluation with:") + print(f" python eval.py --config-name={args.dataset}.yaml policy={run_name}") + + +if __name__ == "__main__": + main() From 03bba355eb3c41e501bd00ececba0623193ec892 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 21:19:42 +0530 Subject: [PATCH 23/29] fixes --- examples/leWorldModel/bench.py | 207 ++++++++++++------ .../leWorldModel/lewm_loader/dataloaders.py | 38 ++-- examples/leWorldModel/lewm_loader/dataset.py | 43 +--- 3 files changed, 168 insertions(+), 120 deletions(-) diff --git a/examples/leWorldModel/bench.py b/examples/leWorldModel/bench.py index 731d79d..cecc776 100644 --- a/examples/leWorldModel/bench.py +++ b/examples/leWorldModel/bench.py @@ -1,15 +1,25 @@ """ leWorldModel dataloader throughput benchmark: LanceDB vs HDF5. -Measures how fast each backend can feed batches to the GPU, independently -of training compute. Three backends: - - LanceDB S3/local — our implementation, parallel workers, no download step - HDF5 local — reads from a local file (best-case for HDF5) - HDF5 s3fs — reads directly from S3 via s3fs (realistic, no download) +Measures raw dataloader throughput (samples/sec) for two backends, independently +of GPU compute. Run this on your target hardware to size num_workers and +batch_size before committing to a long training run. + +Backends +-------- + LanceDB — our implementation; local or S3-backed + HDF5 — reads from a local file (put in /dev/shm for RAM-backed I/O) + HDF5-s3fs — reads the HDF5 file directly from S3 via s3fs (no local copy) + +Usage +----- + # LanceDB local vs HDF5 local + python bench.py \\ + --lance-uri ./lewm_lance \\ + --table-name lewm_pusht \\ + --hdf5-local /path/to/pusht.hdf5 -Usage: - # LanceDB S3 vs HDF5 local + # LanceDB S3 vs HDF5 local (put HDF5 in /dev/shm for best-case HDF5) python bench.py \\ --lance-uri s3://my-bucket/lewm \\ --table-name lewm_pusht \\ @@ -23,8 +33,42 @@ --hdf5-s3-key hdf5/pusht.hdf5 \\ --s3-bucket my-bucket - # Credentials via environment variables (AWS_ACCESS_KEY_ID etc.) - python bench.py --lance-uri s3://my-bucket/lewm --table-name lewm_pusht +Why LanceDB S3 outperforms HDF5 local (/dev/shm) +------------------------------------------------- +Counter-intuitive but consistent across hardware. Three compounding reasons: + +1. HDF5 POSIX file lock — the dominant factor + h5py acquires a POSIX advisory lock at the file level for every read + operation. With num_workers=8, all 8 worker processes try to read + concurrently but are serialized by this lock. Effective throughput is + limited to roughly 1 worker at a time, regardless of how many workers + you spawn. LanceDB opens an independent async connection per worker with + no shared lock — all 8 workers genuinely read in parallel. + +2. JPEG pixel compression — reduces bytes in flight + LanceDB stores pixels as JPEG (~3–5 KB per frame for typical robotics + frames). A training window of T=4 frames requires ~12–20 KB of pixel + data. The equivalent HDF5 raw uint8 window is ~110–550 KB depending on + resolution. Even over S3, LanceDB transfers far less data per sample. + The JPEG decode cost on CPU is small compared to the I/O savings. + +3. Batch-level row fetching via __getitems__ + LeWMLanceDataset.__getitems__ resolves all B×span rows for a batch in a + single async call to Permutation.__getitems__. HDF5 individual seeks, + even when sorted by file offset, are still each subject to the POSIX lock. + +Combined effect: HDF5 has ~1 effective worker reading ~200+ KB/sample; +LanceDB has num_workers effective workers reading ~16 KB/sample. +Theoretical ratio: 8 × (200/16) ≈ 100×. Observed speedup is ~7× due to +real-world overhead (S3 latency, JPEG decode CPU, prefetch pipeline startup). + +Why HDF5-s3fs is worst +----------------------- +HDF5 was designed for local POSIX filesystems. It issues many small random +seeks per batch (one per column per window). Over S3, each seek becomes a +separate HTTP range request. For a T=4 window with 3 columns, that's ~12 +HTTP requests per sample — hundreds per batch — all still serialized by the +POSIX lock. """ import argparse @@ -46,22 +90,31 @@ # --------------------------------------------------------------------------- -# Defaults +# Benchmark defaults — match the training config # --------------------------------------------------------------------------- BATCH_SIZE = 128 -NUM_STEPS = 4 -FRAMESKIP = 5 # must match training config; both backends use the same value +NUM_STEPS = 4 # history_size (3) + num_preds (1) +FRAMESKIP = 5 # le-wm paper default; both backends must use the same value IMAGE_SIZE = 224 NUM_WORKERS = 8 PREFETCH_FACTOR = 3 WARMUP_BATCHES = 5 BENCH_BATCHES = 50 -COLUMNS = ["pixels", "action", "proprio", "state"] # --------------------------------------------------------------------------- -# HDF5 dataset (mirrors the original stable_worldmodel.data.HDF5Dataset) +# HDF5 dataset +# +# Mirrors stable-worldmodel's HDF5Dataset to give HDF5 its best-case numbers: +# - All non-pixel columns cached in RAM at init (no per-sample column I/O) +# - Pixels read with stride=frameskip to minimise seek distance +# - __getitems__ sorts by file offset for sequential access +# - action is stacked (span rows → T × frameskip×action_dim), matching +# the LanceDB training format and le-wm's effective_act_dim convention +# +# Despite these optimisations, the POSIX file lock still limits effective +# parallelism to ~1 worker. See module docstring for details. # --------------------------------------------------------------------------- _TRANSFORM = transforms.Compose([ @@ -73,23 +126,18 @@ class HDF5LeWMDataset(torch.utils.data.Dataset): """ - HDF5-backed temporal-window dataset matching stable-worldmodel's HDF5Dataset. - - The HDF5 schema uses per-episode metadata arrays: - ep_len — shape (n_episodes,) episode lengths - ep_offset — shape (n_episodes,) global start row per episode - - Valid clip_indices are (episode_idx, local_start) pairs where a full window - of span = num_steps * frameskip rows fits within the episode. At read time, - the global slice [offset + local_start : offset + local_start + span] is - fetched and every frameskip-th frame is selected. - - Pixels are stored as (N, H, W, C) uint8 — no transpose needed before PIL. - Non-pixel columns (action, proprio, etc.) are cached in RAM at init time; - only pixels are read from the file at __getitem__ time. - - hdf5_src can be a local file path (str) or an s3fs file object. - h5py is opened lazily per worker because handles are not fork-safe. + HDF5-backed temporal-window dataset. + + Schema expected (stable-worldmodel format): + ep_len — (n_episodes,) episode lengths + ep_offset — (n_episodes,) global start row per episode + pixels — (N, H, W, C) uint8 + action — (N, action_dim) float32 + ... + + All non-pixel columns are fully cached in RAM at __init__ to avoid + repeated random HDF5 seeks for vector data. Only pixels are read + from the file per sample, since caching them would consume too much RAM. """ def __init__(self, hdf5_src, columns, num_steps=NUM_STEPS, frameskip=FRAMESKIP): @@ -103,10 +151,10 @@ def __init__(self, hdf5_src, columns, num_steps=NUM_STEPS, frameskip=FRAMESKIP): with h5py.File(self._src, "r", rdcc_nbytes=256 * 1024 * 1024) as f: ep_len = np.array(f["ep_len"], dtype=np.int32) ep_offset = np.array(f["ep_offset"], dtype=np.int32) - # Cache all non-pixel columns in RAM — avoids repeated random HDF5 seeks + # Cache all non-pixel columns in RAM to avoid repeated HDF5 seeks self._cached: dict[str, np.ndarray] = {} for col in columns: - if col != "pixels": + if col != "pixels" and col in f: self._cached[col] = np.array(f[col], dtype=np.float32) # Build (ep_idx, local_start) pairs for all valid windows @@ -124,37 +172,48 @@ def __len__(self): def __getstate__(self): state = self.__dict__.copy() - state["_file"] = None # h5py handle can't be pickled + state["_file"] = None # h5py handle is not fork-safe return state def _ensure_open(self): if self._file is None: - self._file = h5py.File(self._src, "r", swmr=True, rdcc_nbytes=256 * 1024 * 1024) + # Open without SWMR: represents typical HDF5 usage. + # The POSIX advisory lock acquired here is per-file (by inode), so + # all 8 DataLoader worker processes sharing this file will serialize + # their reads through it regardless of which worker holds the handle. + self._file = h5py.File(self._src, "r", rdcc_nbytes=256 * 1024 * 1024) - def __getitem__(self, clip_idx): + def __getitem__(self, clip_idx: int) -> dict[str, torch.Tensor]: self._ensure_open() ep_idx, local_start = self._clip_indices[clip_idx] g_start = int(self._ep_offset[ep_idx]) + local_start g_end = g_start + self._span - pixels_raw = self._file["pixels"][g_start:g_end:self.frameskip] # (T, H, W, C) + # Pixels: stride by frameskip → (T, H, W, C) then PIL-transform + pixels_raw = self._file["pixels"][g_start:g_end:self.frameskip] frames = [ _TRANSFORM(Image.fromarray(pixels_raw[t].astype(np.uint8))) for t in range(self.num_steps) ] - sample = {"pixels": torch.stack(frames)} + for col in self.columns: - if col == "pixels": + if col == "pixels" or col not in self._cached: continue - data = self._cached[col][g_start:g_end:self.frameskip] - sample[col] = torch.from_numpy(np.nan_to_num(data, nan=0.0)) + if col == "action": + # Stack all span rows: (span, action_dim) → (T, frameskip×action_dim) + # Matches le-wm's effective_act_dim = frameskip × raw_action_dim + data = self._cached[col][g_start:g_end].reshape(self.num_steps, -1) + else: + # Proprio, state, observation: stride by frameskip → (T, D) + data = self._cached[col][g_start:g_end:self.frameskip] + sample[col] = torch.from_numpy(np.nan_to_num(data.astype(np.float32), nan=0.0)) + return sample def __getitems__(self, indices: list[int]) -> list[dict]: - """Batch pixel reads sorted by file offset to improve sequential access.""" + """Sort reads by file offset to minimise seeks — best-case HDF5 access.""" self._ensure_open() - # Sort by file offset so HDF5 reads are as sequential as possible order = sorted(range(len(indices)), key=lambda i: self._clip_indices[indices[i]]) results = [None] * len(indices) for pos in order: @@ -182,13 +241,13 @@ def make_hdf5_loader(hdf5_src, columns, batch_size, num_workers, prefetch_factor # --------------------------------------------------------------------------- -# Benchmark +# Benchmark runner # --------------------------------------------------------------------------- def measure_throughput(loader, label, warmup, steps): """ Iterate the loader for `warmup` batches (discarded), then time `steps` batches. - Returns samples/sec and average batch latency in ms. + Returns a result dict with samples/sec, avg batch latency, and p99 latency. """ print(f"\n{'─' * 60}") print(f" {label}") @@ -203,7 +262,6 @@ def measure_throughput(loader, label, warmup, steps): if batch is None: it = iter(loader) batch = next(it) - # Touch the pixels tensor to ensure decoding actually happened _ = batch["pixels"].shape print(f" benchmarking ({steps} batches)...") @@ -240,16 +298,25 @@ def _build_parser(): description="Benchmark LanceDB vs HDF5 dataloader throughput for leWorldModel", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - p.add_argument("--lance-uri", required=True) - p.add_argument("--table-name", required=True) - p.add_argument("--hdf5-local", default=None, help="Path to local HDF5 file") - p.add_argument("--hdf5-s3-key", default=None, help="S3 object key for HDF5 file") - p.add_argument("--s3-bucket", default=None, help="S3 bucket (for --hdf5-s3-key)") - p.add_argument("--columns", nargs="+", default=COLUMNS) - p.add_argument("--batch-size", type=int, default=BATCH_SIZE) - p.add_argument("--num-workers", type=int, default=NUM_WORKERS) - p.add_argument("--warmup", type=int, default=WARMUP_BATCHES) - p.add_argument("--steps", type=int, default=BENCH_BATCHES) + p.add_argument("--lance-uri", required=True, + help="LanceDB URI (local path or s3://bucket/prefix)") + p.add_argument("--table-name", required=True, + help="LanceDB table name (e.g. lewm_pusht)") + p.add_argument("--hdf5-local", default=None, + help="Path to local HDF5 file (use /dev/shm for RAM-backed best-case)") + p.add_argument("--hdf5-s3-key", default=None, + help="S3 object key for HDF5 file (requires --s3-bucket)") + p.add_argument("--s3-bucket", default=None, + help="S3 bucket name (for --hdf5-s3-key)") + p.add_argument("--columns", nargs="+", + default=["pixels", "action", "proprio"], + help="Columns to load. Use dataset-appropriate columns " + "(pusht: pixels action proprio state; " + "reacher/cube: pixels action observation)") + p.add_argument("--batch-size", type=int, default=BATCH_SIZE) + p.add_argument("--num-workers", type=int, default=NUM_WORKERS) + p.add_argument("--warmup", type=int, default=WARMUP_BATCHES) + p.add_argument("--steps", type=int, default=BENCH_BATCHES) s3 = p.add_argument_group("S3 credentials (fall back to AWS_* env vars)") s3.add_argument("--aws-access-key-id", default=os.environ.get("AWS_ACCESS_KEY_ID")) @@ -266,10 +333,10 @@ def main(): print(f"\nleWorldModel dataloader benchmark") print(f" batch_size : {args.batch_size}") print(f" num_workers : {args.num_workers}") - print(f" T (frames) : {NUM_STEPS}") + print(f" T (frames) : {NUM_STEPS} frameskip: {FRAMESKIP}") print(f" warmup : {args.warmup} batches bench: {args.steps} batches") + print(f" columns : {args.columns}") - # Build S3 storage_options for LanceDB storage_options = {} if args.aws_access_key_id: storage_options["aws_access_key_id"] = args.aws_access_key_id @@ -310,7 +377,8 @@ def main(): # 2. HDF5 local if args.hdf5_local: hdf5_local_loader = make_hdf5_loader( - args.hdf5_local, args.columns, args.batch_size, args.num_workers, PREFETCH_FACTOR, + args.hdf5_local, args.columns, + args.batch_size, args.num_workers, PREFETCH_FACTOR, ) results.append(measure_throughput( hdf5_local_loader, @@ -318,7 +386,7 @@ def main(): args.warmup, args.steps, )) - # 3. HDF5 via s3fs (reads directly from S3, no local copy) + # 3. HDF5 via s3fs (reads directly from S3 without downloading) if args.hdf5_s3_key and args.s3_bucket: import s3fs s3_kwargs = {} @@ -340,7 +408,8 @@ def main(): s3_file = fs.open(f"{args.s3_bucket}/{args.hdf5_s3_key}", "rb") hdf5_s3_loader = make_hdf5_loader( - s3_file, args.columns, args.batch_size, args.num_workers, PREFETCH_FACTOR, + s3_file, args.columns, + args.batch_size, args.num_workers, PREFETCH_FACTOR, ) results.append(measure_throughput( hdf5_s3_loader, @@ -348,16 +417,22 @@ def main(): args.warmup, args.steps, )) - # Summary table + # Summary table — baseline is the slowest backend if len(results) > 1: - baseline = results[-1]["samples_sec"] + baseline = min(r["samples_sec"] for r in results) print(f"\n{'=' * 60}") - print(f" {'Backend':<46} {'samples/sec':>12} {'avg ms':>8} {'speedup':>8}") + print(f" {'Backend':<44} {'samples/sec':>12} {'avg ms':>8} {'speedup':>8}") print(f"{'─' * 60}") for r in sorted(results, key=lambda x: -x["samples_sec"]): speedup = r["samples_sec"] / baseline - print(f" {r['label']:<46} {r['samples_sec']:>12,.0f} {r['avg_ms']:>7.1f} {speedup:>7.1f}×") + print(f" {r['label']:<44} {r['samples_sec']:>12,.0f} {r['avg_ms']:>7.1f} {speedup:>7.1f}×") print(f"{'=' * 60}") + print() + print("Key: LanceDB S3 > HDF5 local despite the network hop because:") + print(" - HDF5 POSIX lock serialises all worker reads (effective parallelism ~1)") + print(" - LanceDB: each worker holds an independent S3 connection (true parallelism)") + print(" - LanceDB JPEG pixels: ~13× smaller than raw HDF5 uint8 → less I/O per sample") + print(" - LanceDB __getitems__: entire batch in one async round trip vs N serial seeks") if __name__ == "__main__": diff --git a/examples/leWorldModel/lewm_loader/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py index 8cf9e69..1fec1f7 100644 --- a/examples/leWorldModel/lewm_loader/dataloaders.py +++ b/examples/leWorldModel/lewm_loader/dataloaders.py @@ -2,7 +2,7 @@ DataLoader factories for leWorldModel LanceDB-backed training. Two public functions: - make_lewm_lance_loader() – single loader, caller provides a pre-built dataset + make_lewm_lance_loader() – single loader (no split) make_train_val_loaders() – episode-level train/val split, returns two loaders Episode-level split (not random-row split) avoids data leakage: @@ -14,7 +14,7 @@ import torch from torch.utils.data import DataLoader -from .dataset import LeWMLanceDataset, compute_normalizers +from .dataset import LeWMLanceDataset # --------------------------------------------------------------------------- @@ -40,11 +40,11 @@ def make_lewm_lance_loader( img_size: int = 224, num_workers: int = 6, prefetch_factor: int = 3, - normalizers: dict | None = None, + shuffle: bool = False, **connect_kwargs, ) -> DataLoader: """ - Build a DataLoader over a LanceDB leWorldModel table. + Build a single DataLoader over a LanceDB leWorldModel table. frameskip=5 matches the le-wm paper default. With T=4 and frameskip=5, each window spans 20 raw rows; action is reshaped to (T, 5×action_dim). @@ -56,10 +56,9 @@ def make_lewm_lance_loader( num_steps=num_steps, frameskip=frameskip, img_size=img_size, - normalizers=normalizers, **connect_kwargs, ) - return _build_loader(dataset, batch_size, num_workers, prefetch_factor) + return _build_loader(dataset, batch_size, num_workers, prefetch_factor, shuffle=shuffle) def make_train_val_loaders( @@ -80,7 +79,8 @@ def make_train_val_loaders( Episode-level train/val split. val_fraction of episodes (randomly sampled, seeded) are held out for - validation. Normalizers are computed on training episodes only. + validation. All timesteps within an episode go entirely to one split — + no row-level leakage between train and val. Returns: (train_loader, val_loader) @@ -100,12 +100,10 @@ def make_train_val_loaders( print(f" Split: {len(train_episodes)} train episodes, {len(val_episodes)} val episodes") - # Compute normalizers on training data only to avoid leakage - normalizers = compute_normalizers(uri, table_name, columns, **connect_kwargs) - - # Build full datasets then restrict _window_starts by episode membership - train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, normalizers, **connect_kwargs) - val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, normalizers, **connect_kwargs) + # Build full datasets then restrict _window_starts by episode membership. + # Both datasets share the same table — no data is copied. + train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, **connect_kwargs) + val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, **connect_kwargs) train_ep_mask = np.isin(train_ds._ep[train_ds._window_starts], list(train_episodes)) val_ep_mask = np.isin(val_ds._ep[val_ds._window_starts], list(val_episodes)) @@ -116,8 +114,8 @@ def make_train_val_loaders( print(f" Windows: {len(train_ds):,} train, {len(val_ds):,} val") return ( - _build_loader(train_ds, batch_size, num_workers, prefetch_factor), - _build_loader(val_ds, batch_size, num_workers, prefetch_factor), + _build_loader(train_ds, batch_size, num_workers, prefetch_factor, shuffle=True), + _build_loader(val_ds, batch_size, num_workers, prefetch_factor, shuffle=False), ) @@ -125,11 +123,17 @@ def make_train_val_loaders( # Internal # --------------------------------------------------------------------------- -def _build_loader(dataset: LeWMLanceDataset, batch_size: int, num_workers: int, prefetch_factor: int) -> DataLoader: +def _build_loader( + dataset: LeWMLanceDataset, + batch_size: int, + num_workers: int, + prefetch_factor: int, + shuffle: bool = False, +) -> DataLoader: return DataLoader( dataset, batch_size=batch_size, - shuffle=False, + shuffle=shuffle, num_workers=num_workers, pin_memory=True, drop_last=True, diff --git a/examples/leWorldModel/lewm_loader/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py index f0c3da1..d6f60c2 100644 --- a/examples/leWorldModel/lewm_loader/dataset.py +++ b/examples/leWorldModel/lewm_loader/dataset.py @@ -14,7 +14,7 @@ Each dataset item is a dict of tensors: "pixels" : (T, C, H, W) float32 ImageNet-normalized - "action" : (T, frameskip×A) float32 z-score normalized, NaN→0 + "action" : (T, frameskip×A) float32 NaN→0 "proprio" : (T, P) float32 [if present] ... @@ -55,32 +55,6 @@ def _jpeg_to_tensor(jpeg_bytes: bytes, transform: transforms.Compose) -> torch.T return transform(img) -def compute_normalizers( - uri: str, - table_name: str, - columns: list[str], - **connect_kwargs, -) -> dict[str, tuple[np.ndarray, np.ndarray]]: - """ - Compute per-column (mean, std) for z-score normalization. - Only reads non-pixel columns one at a time; pixel/blob data is never loaded. - """ - db = lancedb.connect(uri, **connect_kwargs) - tbl = db.open_table(table_name) - non_pixel = [c for c in columns if c != "pixels"] - if not non_pixel: - return {} - ds = tbl.to_lance() - normalizers = {} - for col in non_pixel: - col_table = ds.to_table(columns=[col]) - data = np.array(col_table[col].to_pylist(), dtype=np.float32) - valid = ~np.isnan(data).any(axis=1) - data = data[valid] - normalizers[col] = (data.mean(axis=0), data.std(axis=0)) - return normalizers - - class LeWMLanceDataset(torch.utils.data.Dataset): """ Temporal-window dataset backed by a LanceDB table. @@ -95,7 +69,6 @@ class LeWMLanceDataset(torch.utils.data.Dataset): action is kept at full resolution and reshaped to (T, frameskip × action_dim); all other columns are strided. img_size: Target image size after resize. - normalizers: {col: (mean, std)} from compute_normalizers(). **connect_kwargs: Passed to lancedb.connect(). """ @@ -107,7 +80,6 @@ def __init__( num_steps: int = 4, frameskip: int = 5, img_size: int = 224, - normalizers: dict | None = None, **connect_kwargs, ): self.uri = uri @@ -116,7 +88,6 @@ def __init__( self.num_steps = num_steps self.frameskip = frameskip self.img_size = img_size - self.normalizers = normalizers or {} self.connect_kwargs = connect_kwargs self._span = num_steps * frameskip # raw rows per window @@ -186,17 +157,15 @@ def _rows_to_sample(self, batch: pa.RecordBatch) -> dict[str, torch.Tensor]: continue if col == "action": - # to_pylist() batches the Arrow → Python conversion in one call - # instead of .as_py() in a Python loop + # Keep all span rows, reshape to (T, frameskip × action_dim). + # This matches le-wm's effective_act_dim = frameskip × raw_action_dim. data = np.array(batch.column(col).to_pylist(), dtype=np.float32) data = np.nan_to_num(data, nan=0.0) - data = data.reshape(T, -1) # (T, frameskip × action_dim) + data = data.reshape(T, -1) else: + # Proprio, state, observation: stride by frameskip → (T, D) data = np.array(batch.column(col).to_pylist(), dtype=np.float32) - data = data[::frameskip] # stride without a Python loop - if col in self.normalizers: - mean, std = self.normalizers[col] - data = (data - mean) / (std + 1e-8) + data = data[::frameskip] data = np.nan_to_num(data, nan=0.0) sample[col] = torch.from_numpy(data) From 067c538800355145113008359d4eb919f4b125a3 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Fri, 3 Apr 2026 21:26:41 +0530 Subject: [PATCH 24/29] update --- examples/leWorldModel/lewm_loader/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/leWorldModel/lewm_loader/__init__.py b/examples/leWorldModel/lewm_loader/__init__.py index f00b070..ce47cf6 100644 --- a/examples/leWorldModel/lewm_loader/__init__.py +++ b/examples/leWorldModel/lewm_loader/__init__.py @@ -1,9 +1,8 @@ -from .dataset import LeWMLanceDataset, compute_normalizers +from .dataset import LeWMLanceDataset from .dataloaders import make_lewm_lance_loader, make_train_val_loaders __all__ = [ "LeWMLanceDataset", - "compute_normalizers", "make_lewm_lance_loader", "make_train_val_loaders", ] From e0f0c248efda53847bf89b6d4e0ee538b4033aad Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 4 Apr 2026 11:41:37 +0530 Subject: [PATCH 25/29] update prep --- examples/leWorldModel/prepare_eval.py | 93 ++++++++++++++++++++++++--- 1 file changed, 85 insertions(+), 8 deletions(-) diff --git a/examples/leWorldModel/prepare_eval.py b/examples/leWorldModel/prepare_eval.py index 8d1c38e..3c5c56e 100644 --- a/examples/leWorldModel/prepare_eval.py +++ b/examples/leWorldModel/prepare_eval.py @@ -5,25 +5,100 @@ - checkpoint at $STABLEWM_HOME/_object.ckpt - policy argument passed as (without _object.ckpt suffix) -This script copies/symlinks the checkpoint to the right location and prints -the exact eval.py command to run. +eval.py also needs the source HDF5 file at $STABLEWM_HOME/.h5 +to sample episode starting states and goals. This script downloads it from +HuggingFace automatically if it is not already present. + +This script handles all of the above and prints the exact eval.py command to run. Usage: python prepare_eval.py --checkpoint checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt python prepare_eval.py --checkpoint checkpoints/lewm_pusht_lewm_epoch_10_object.ckpt --run-name lewm_pusht """ import argparse +import glob import os import shutil +import subprocess from pathlib import Path +# HuggingFace repo and expected HDF5 filename for each dataset +_DATASET_META = { + "pusht": ("quentinll/lewm-pusht", "pusht_expert_train.h5"), + "cube": ("quentinll/lewm-cube", "cube_single_expert.h5"), + "reacher": ("quentinll/lewm-reacher", "reacher.h5"), + "tworoom": ("quentinll/lewm-tworooms", "tworoom.h5"), +} + + +def _ensure_hdf5(dataset: str, stablewm_home: Path) -> Path: + """ + Ensure the source HDF5 file is present at $STABLEWM_HOME/.h5. + Downloads from HuggingFace if missing. + """ + hf_repo, hdf5_name = _DATASET_META[dataset] + dst = stablewm_home / hdf5_name + if dst.exists(): + return dst + + # Check if already cached from a previous create_data.py run + cache_dir = stablewm_home / "datasets" / hf_repo.replace("/", "--") + existing = glob.glob(str(cache_dir / "*.h5")) + glob.glob(str(cache_dir / "*.hdf5")) + if existing: + dst.symlink_to(existing[0]) + print(f"Linked HDF5 from cache → {dst}") + return dst + + # Download from HuggingFace + print(f"HDF5 not found. Downloading {hf_repo} from HuggingFace...") + cache_dir.mkdir(parents=True, exist_ok=True) + try: + from huggingface_hub import list_repo_files, hf_hub_download + except ImportError: + raise SystemExit("huggingface_hub not installed. Run: pip install huggingface_hub") + + repo_files = list(list_repo_files(hf_repo, repo_type="dataset")) + data_file = next( + (f for f in repo_files + if f.endswith(".tar.zst") or f.endswith(".h5.zst") + or f.endswith(".h5") or f.endswith(".hdf5")), + None, + ) + if not data_file: + raise FileNotFoundError(f"No HDF5 file found in HuggingFace repo {hf_repo}") + + local = hf_hub_download( + repo_id=hf_repo, filename=data_file, + repo_type="dataset", local_dir=str(cache_dir), + ) + + if local.endswith(".tar.zst"): + subprocess.run( + ["tar", "--use-compress-program=unzstd", "-xf", local, "-C", str(cache_dir)], + check=True, + ) + os.remove(local) + elif local.endswith(".h5.zst"): + out = local[:-4] + subprocess.run(["zstd", "-d", local, "-o", out], check=True) + os.remove(local) + + h5_files = glob.glob(str(cache_dir / "*.h5")) + glob.glob(str(cache_dir / "*.hdf5")) + if not h5_files: + raise FileNotFoundError(f"No .h5 file found in {cache_dir} after download") + + dst.symlink_to(h5_files[0]) + print(f"Downloaded and linked HDF5 → {dst}") + return dst + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True, help="Path to *_object.ckpt file") parser.add_argument("--run-name", default=None, help="Name to use under STABLEWM_HOME (default: derived from checkpoint filename)") - parser.add_argument("--dataset", default="pusht", choices=["pusht", "cube", "reacher", "tworoom"], + parser.add_argument("--dataset", default="pusht", choices=list(_DATASET_META), help="Dataset to evaluate on") parser.add_argument("--copy", action="store_true", help="Copy instead of symlinking (use when src and dst are on different filesystems)") @@ -42,19 +117,21 @@ def main(): stem = stem[: -len("_object")] # lewm_pusht_lewm_epoch_10 run_name = args.run_name or stem + # 1. Link checkpoint dst = stablewm_home / f"{run_name}_object.ckpt" - if dst.exists() or dst.is_symlink(): dst.unlink() - if args.copy: shutil.copy2(ckpt, dst) - print(f"Copied {ckpt}") + print(f"Copied {ckpt}") else: dst.symlink_to(ckpt) - print(f"Symlinked {ckpt}") + print(f"Symlinked {ckpt}") + print(f" → {dst}") + + # 2. Ensure HDF5 is present + _ensure_hdf5(args.dataset, stablewm_home) - print(f" → {dst}") print() print("Run evaluation with:") print(f" python eval.py --config-name={args.dataset}.yaml policy={run_name}") From b8803f88ec8789d685c10a5e71435a6deeab51d9 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 4 Apr 2026 12:45:45 +0530 Subject: [PATCH 26/29] codex fix --- .../leWorldModel/lewm_loader/dataloaders.py | 113 +++++++++++++++++- examples/leWorldModel/lewm_loader/dataset.py | 21 +++- 2 files changed, 130 insertions(+), 4 deletions(-) diff --git a/examples/leWorldModel/lewm_loader/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py index 1fec1f7..5d1acfb 100644 --- a/examples/leWorldModel/lewm_loader/dataloaders.py +++ b/examples/leWorldModel/lewm_loader/dataloaders.py @@ -17,6 +17,89 @@ from .dataset import LeWMLanceDataset +def _update_running_stats(entry: dict, data: np.ndarray) -> None: + """Numerically stable running mean/std update (per-dimension).""" + + if data.size == 0: + return + + batch_count = data.shape[0] + batch_mean = data.mean(axis=0, dtype=np.float64) + batch_m2 = ((data - batch_mean) ** 2).sum(axis=0, dtype=np.float64) + + if entry["count"] == 0: + entry["count"] = batch_count + entry["mean"] = batch_mean + entry["m2"] = batch_m2 + return + + total = entry["count"] + batch_count + delta = batch_mean - entry["mean"] + entry["mean"] = entry["mean"] + delta * (batch_count / total) + entry["m2"] = ( + entry["m2"] + + batch_m2 + + (delta**2) * entry["count"] * batch_count / total + ) + entry["count"] = total + + +def _compute_column_normalizers( + uri: str, + table_name: str, + columns: list[str], + train_episodes: set[int], + connect_kwargs: dict, +) -> dict[str, dict[str, np.ndarray]]: + """Compute per-column (mean,std) stats on the training episodes only.""" + + norm_cols = [c for c in columns if c != "pixels"] + if not norm_cols or not train_episodes: + return {} + + db = lancedb.connect(uri, **connect_kwargs) + tbl = db.open_table(table_name) + lance_ds = tbl.to_lance() + scanner = lance_ds.scanner( + columns=["episode_idx", *norm_cols], + batch_size=8192, + ) + + stats = {col: {"count": 0, "mean": None, "m2": None} for col in norm_cols} + episode_ids = np.array(sorted(train_episodes), dtype=np.int32) + + for batch in scanner.to_batches(): + ep = np.array(batch["episode_idx"].to_pylist(), dtype=np.int32) + mask = np.isin(ep, episode_ids) + if not mask.any(): + continue + + for col in norm_cols: + arr = np.array(batch[col].to_pylist(), dtype=np.float32) + arr = arr[mask] + if arr.ndim == 1: + arr = arr[:, None] + arr = arr[~np.isnan(arr).any(axis=1)] + if arr.size == 0: + continue + _update_running_stats(stats[col], arr) + + normalizers: dict[str, dict[str, np.ndarray]] = {} + for col, entry in stats.items(): + if entry["count"] == 0: + continue + mean = entry["mean"].astype(np.float32) + if entry["count"] > 1: + var = entry["m2"] / (entry["count"] - 1) + else: + var = np.ones_like(mean, dtype=np.float64) + std = np.sqrt(var).astype(np.float32) + std = np.where(std > 1e-6, std, np.ones_like(std)) + normalizers[col] = {"mean": mean, "std": std} + + return normalizers + + # --------------------------------------------------------------------------- # Collate: list[{key: (T,...) tensor}] → {key: (B,T,...) tensor} # --------------------------------------------------------------------------- @@ -100,10 +183,36 @@ def make_train_val_loaders( print(f" Split: {len(train_episodes)} train episodes, {len(val_episodes)} val episodes") + normalizers = _compute_column_normalizers( + uri=uri, + table_name=table_name, + columns=columns, + train_episodes=train_episodes, + connect_kwargs=connect_kwargs, + ) + # Build full datasets then restrict _window_starts by episode membership. # Both datasets share the same table — no data is copied. - train_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, **connect_kwargs) - val_ds = LeWMLanceDataset(uri, table_name, columns, num_steps, frameskip, img_size, **connect_kwargs) + train_ds = LeWMLanceDataset( + uri, + table_name, + columns, + num_steps, + frameskip, + img_size, + normalizers=normalizers, + **connect_kwargs, + ) + val_ds = LeWMLanceDataset( + uri, + table_name, + columns, + num_steps, + frameskip, + img_size, + normalizers=normalizers, + **connect_kwargs, + ) train_ep_mask = np.isin(train_ds._ep[train_ds._window_starts], list(train_episodes)) val_ep_mask = np.isin(val_ds._ep[val_ds._window_starts], list(val_episodes)) diff --git a/examples/leWorldModel/lewm_loader/dataset.py b/examples/leWorldModel/lewm_loader/dataset.py index d6f60c2..cbfb1c9 100644 --- a/examples/leWorldModel/lewm_loader/dataset.py +++ b/examples/leWorldModel/lewm_loader/dataset.py @@ -80,6 +80,7 @@ def __init__( num_steps: int = 4, frameskip: int = 5, img_size: int = 224, + normalizers: dict[str, dict[str, np.ndarray]] | None = None, **connect_kwargs, ): self.uri = uri @@ -93,6 +94,13 @@ def __init__( self._perm: Permutation | None = None self._transform: transforms.Compose | None = None + self._normalizers: dict[str, dict[str, np.ndarray]] = {} + if normalizers: + for col, stats in normalizers.items(): + mean = np.array(stats["mean"], dtype=np.float32) + std = np.array(stats["std"], dtype=np.float32) + std = np.where(std > 1e-6, std, np.ones_like(std)) + self._normalizers[col] = {"mean": mean, "std": std} # Load only the two int32 index columns to precompute valid windows. # Pixels and all other data columns are never touched here. @@ -159,14 +167,23 @@ def _rows_to_sample(self, batch: pa.RecordBatch) -> dict[str, torch.Tensor]: if col == "action": # Keep all span rows, reshape to (T, frameskip × action_dim). # This matches le-wm's effective_act_dim = frameskip × raw_action_dim. - data = np.array(batch.column(col).to_pylist(), dtype=np.float32) + data = np.array(batch[col].to_pylist(), dtype=np.float32) data = np.nan_to_num(data, nan=0.0) + data = data.reshape(T, frameskip, -1) + if col in self._normalizers: + norm = self._normalizers[col] + mean = norm["mean"][None, None, :] + std = norm["std"][None, None, :] + data = (data - mean) / std data = data.reshape(T, -1) else: # Proprio, state, observation: stride by frameskip → (T, D) - data = np.array(batch.column(col).to_pylist(), dtype=np.float32) + data = np.array(batch[col].to_pylist(), dtype=np.float32) data = data[::frameskip] data = np.nan_to_num(data, nan=0.0) + if col in self._normalizers: + norm = self._normalizers[col] + data = (data - norm["mean"]) / norm["std"] sample[col] = torch.from_numpy(data) From a8536428bc74d7518898059b2bcb7dd2dea7858b Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 4 Apr 2026 15:42:06 +0530 Subject: [PATCH 27/29] more codex debug --- .../leWorldModel/lewm_loader/dataloaders.py | 4 ++ examples/leWorldModel/train.py | 51 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/examples/leWorldModel/lewm_loader/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py index 5d1acfb..ace280d 100644 --- a/examples/leWorldModel/lewm_loader/dataloaders.py +++ b/examples/leWorldModel/lewm_loader/dataloaders.py @@ -183,6 +183,7 @@ def make_train_val_loaders( print(f" Split: {len(train_episodes)} train episodes, {len(val_episodes)} val episodes") + print(" Computing column normalizers on train episodes...") normalizers = _compute_column_normalizers( uri=uri, table_name=table_name, @@ -191,6 +192,9 @@ def make_train_val_loaders( connect_kwargs=connect_kwargs, ) + for col, stats in normalizers.items(): + print(f" {col}: mean={stats['mean'].tolist()}, std={stats['std'].tolist()}") + # Build full datasets then restrict _window_starts by episode membership. # Both datasets share the same table — no data is copied. train_ds = LeWMLanceDataset( diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index 6e412d2..f4fb25b 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -38,8 +38,10 @@ import yaml import pytorch_lightning as pl from pathlib import Path +from typing import Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers import WandbLogger +import csv from jepa import JEPA from lewm_loader import make_train_val_loaders @@ -99,11 +101,14 @@ class LeWMLightning(pl.LightningModule): never called here. """ - def __init__(self, model: JEPA, sigreg: SIGReg, cfg: dict): + def __init__(self, model: JEPA, sigreg: SIGReg, cfg: dict, debug_path: Optional[str] = None): super().__init__() self.model = model self.sigreg = sigreg self.cfg = cfg + self.debug_path = Path(debug_path) if debug_path else None + self._debug_header_written = False + self._debug_fieldnames: Optional[list[str]] = None self.save_hyperparameters(ignore=["model", "sigreg"]) def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: @@ -124,14 +129,32 @@ def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: tgt_emb = emb[:, n_preds:] # ground-truth targets (shifted by n_preds) pred_emb = self.model.predict(ctx_emb, ctx_act) + emb_std = emb.detach().float().std().item() + act_std = act_emb.detach().float().std().item() # SIGReg expects (T, B, D) — transpose time and batch dims loss_pred = (pred_emb - tgt_emb).pow(2).mean() loss_reg = self.sigreg(emb.transpose(0, 1)) loss = loss_pred + self.cfg["sigreg_weight"] * loss_reg + pred_std = pred_emb.detach().float().std().item() self.log(f"{stage}/loss_pred", loss_pred, on_step=(stage == "train"), on_epoch=True, prog_bar=True) self.log(f"{stage}/loss_reg", loss_reg, on_step=(stage == "train"), on_epoch=True) self.log(f"{stage}/loss", loss, on_step=(stage == "train"), on_epoch=True, prog_bar=True) + self.log(f"{stage}/emb_std", emb_std, on_step=(stage == "train"), on_epoch=True) + self.log(f"{stage}/act_emb_std", act_std, on_step=(stage == "train"), on_epoch=True) + self.log(f"{stage}/pred_emb_std", pred_std, on_step=(stage == "train"), on_epoch=True) + + self._write_debug_row(stage, { + "stage": stage, + "global_step": int(self.global_step), + "epoch": int(self.current_epoch), + "loss": float(loss.detach().item()), + "loss_pred": float(loss_pred.detach().item()), + "loss_reg": float(loss_reg.detach().item()), + "emb_std": emb_std, + "act_emb_std": act_std, + "pred_emb_std": pred_std, + }) return loss def training_step(self, batch: dict, _) -> torch.Tensor: @@ -161,6 +184,22 @@ def configure_optimizers(self): ) return [opt], [{"scheduler": sched, "interval": "step"}] + def _write_debug_row(self, stage: str, row: dict) -> None: + if not self.debug_path: + return + path = self.debug_path + if path.parent and not path.parent.exists(): + path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = self._debug_fieldnames or list(row.keys()) + need_header = not self._debug_header_written and not path.exists() + with path.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if need_header: + writer.writeheader() + self._debug_header_written = True + writer.writerow(row) + self._debug_fieldnames = fieldnames + # --------------------------------------------------------------------------- # Model construction @@ -277,6 +316,8 @@ def _build_parser() -> argparse.ArgumentParser: help="Run 1 train+val batch then exit (smoke test)") parser.add_argument("--precision", default=None, help="Override trainer.precision (e.g. 32, 16-mixed, bf16-mixed)") + parser.add_argument("--debug-log", default=None, + help="Optional path to write per-step debug metrics (emb/action stds, losses)") s3 = parser.add_argument_group("S3 storage options") s3.add_argument("--aws-access-key-id", default=None, metavar="KEY") s3.add_argument("--aws-secret-access-key", default=None, metavar="SECRET") @@ -353,9 +394,15 @@ def main(): "sigreg_weight": cfg["loss"]["sigreg"]["weight"], "history_size": wm_cfg["history_size"], "num_preds": wm_cfg["num_preds"], + "debug_log": args.debug_log, } - lightning_model = LeWMLightning(model=model, sigreg=sigreg, cfg=lightning_cfg) + lightning_model = LeWMLightning( + model=model, + sigreg=sigreg, + cfg=lightning_cfg, + debug_path=args.debug_log, + ) # ------------------------------------------------------------------ # # Logging & callbacks From 7db23990e5ed8d0b0fd613d64c9c6b2d4b3226c0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sat, 4 Apr 2026 23:01:05 +0530 Subject: [PATCH 28/29] fix weight decay --- examples/leWorldModel/train.py | 35 +++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/examples/leWorldModel/train.py b/examples/leWorldModel/train.py index f4fb25b..fa5b2fb 100644 --- a/examples/leWorldModel/train.py +++ b/examples/leWorldModel/train.py @@ -126,7 +126,7 @@ def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: # Predict next states from the context window (first history_size frames) ctx_emb = emb[:, :ctx_len] ctx_act = act_emb[:, :ctx_len] - tgt_emb = emb[:, n_preds:] # ground-truth targets (shifted by n_preds) + tgt_emb = emb[:, n_preds:] pred_emb = self.model.predict(ctx_emb, ctx_act) emb_std = emb.detach().float().std().item() @@ -164,10 +164,11 @@ def validation_step(self, batch: dict, _) -> None: self._shared_step(batch, "val") def configure_optimizers(self): + param_groups = self._build_param_groups(self.model) opt = torch.optim.AdamW( - self.model.parameters(), + param_groups, lr=self.cfg["lr"], - weight_decay=self.cfg["weight_decay"], + weight_decay=0.0, ) # Replicate le-wm's LinearWarmupCosineAnnealingLR exactly: # warmup_steps = 1% of total steps (step-based, not epoch-based) @@ -200,6 +201,34 @@ def _write_debug_row(self, stage: str, row: dict) -> None: writer.writerow(row) self._debug_fieldnames = fieldnames + def _build_param_groups(self, module: nn.Module) -> list[dict]: + decay: list[torch.nn.Parameter] = [] + no_decay: list[torch.nn.Parameter] = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + continue + if _is_bias_or_norm(name, param): + no_decay.append(param) + else: + decay.append(param) + groups: list[dict] = [] + if decay: + groups.append({"params": decay, "weight_decay": self.cfg["weight_decay"]}) + if no_decay: + groups.append({"params": no_decay, "weight_decay": 0.0}) + return groups + + +def _is_bias_or_norm(name: str, param: torch.nn.Parameter) -> bool: + if name.endswith(".bias") or name.endswith("bias"): + return True + lname = name.lower() + if "norm" in lname: + return True + if param.ndim == 1: + return True + return False + # --------------------------------------------------------------------------- # Model construction From 476b3f9b62bdb0505d2abeeefba1ce6c8aebffa5 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Sun, 5 Apr 2026 09:11:37 +0530 Subject: [PATCH 29/29] update codex --- examples/leWorldModel/config/lewm_pusht.yaml | 42 +++++------ .../leWorldModel/lewm_loader/dataloaders.py | 72 ++++++++----------- 2 files changed, 47 insertions(+), 67 deletions(-) diff --git a/examples/leWorldModel/config/lewm_pusht.yaml b/examples/leWorldModel/config/lewm_pusht.yaml index fcf0184..29bb82c 100644 --- a/examples/leWorldModel/config/lewm_pusht.yaml +++ b/examples/leWorldModel/config/lewm_pusht.yaml @@ -1,35 +1,25 @@ -# leWorldModel training config — PushT dataset -# Mirrors le-wm/config/train/lewm.yaml + config/data/pusht.yaml -# LanceDB-specific keys added under 'data:' (not present in original le-wm). -# All other sections are a strict superset of the original le-wm config. - seed: 42 img_size: 224 - trainer: max_epochs: 100 - precision: "bf16-mixed" # PyTorch Lightning format (le-wm uses "bf16" via Hydra) + precision: bf16-mixed gradient_clip_val: 1.0 log_every_n_steps: 50 save_every_n_epochs: 10 - checkpoint_dir: "./checkpoints" - + checkpoint_dir: ./checkpoints loader: batch_size: 128 num_workers: 12 prefetch_factor: 4 - optimizer: - lr: 5.0e-5 - weight_decay: 1.0e-3 - + lr: 5.0e-05 + weight_decay: 0.001 wm: history_size: 3 num_preds: 1 embed_dim: 192 - patch_size: 14 # ViT patch size — with img_size=224 gives 16×16=256 patches + CLS - proj_hidden: 2048 # added: projector MLP hidden dim - + patch_size: 14 + proj_hidden: 2048 predictor: depth: 6 heads: 16 @@ -37,20 +27,20 @@ predictor: dim_head: 64 dropout: 0.1 emb_dropout: 0.0 - loss: sigreg: weight: 0.09 kwargs: knots: 17 num_proj: 1024 - -# ── LanceDB-specific additions (not present in original le-wm) ──────────────── data: - lance_uri: "./lewm_lance" # override with s3://bucket/prefix for cloud - table_name: "lewm_pusht" - columns: ["pixels", "action", "proprio", "state"] - frameskip: 5 # matches le-wm paper default - val_fraction: 0.1 # episode-level hold-out fraction - -wandb_project: "lewm-lancedb" + lance_uri: ./lewm_lance + table_name: lewm_pusht + columns: + - pixels + - action + - proprio + - state + frameskip: 5 + val_fraction: 0.1 +wandb_project: lewm-lancedb diff --git a/examples/leWorldModel/lewm_loader/dataloaders.py b/examples/leWorldModel/lewm_loader/dataloaders.py index ace280d..b9a7c66 100644 --- a/examples/leWorldModel/lewm_loader/dataloaders.py +++ b/examples/leWorldModel/lewm_loader/dataloaders.py @@ -3,10 +3,7 @@ Two public functions: make_lewm_lance_loader() – single loader (no split) - make_train_val_loaders() – episode-level train/val split, returns two loaders - -Episode-level split (not random-row split) avoids data leakage: - all timesteps of a given episode go entirely to train or entirely to val. + make_train_val_loaders() – random window train/val split, returns two loaders """ import lancedb @@ -48,13 +45,13 @@ def _compute_column_normalizers( uri: str, table_name: str, columns: list[str], - train_episodes: set[int], + train_episodes: set[int] | None, connect_kwargs: dict, ) -> dict[str, dict[str, np.ndarray]]: - """Compute per-column (mean,std) stats on the training episodes only.""" + """Compute per-column (mean,std) stats on selected episodes (or all).""" norm_cols = [c for c in columns if c != "pixels"] - if not norm_cols or not train_episodes: + if not norm_cols: return {} db = lancedb.connect(uri, **connect_kwargs) @@ -66,11 +63,18 @@ def _compute_column_normalizers( ) stats = {col: {"count": 0, "mean": None, "m2": None} for col in norm_cols} - episode_ids = np.array(sorted(train_episodes), dtype=np.int32) + episode_ids = ( + np.array(sorted(train_episodes), dtype=np.int32) + if train_episodes is not None + else None + ) for batch in scanner.to_batches(): ep = np.array(batch["episode_idx"].to_pylist(), dtype=np.int32) - mask = np.isin(ep, episode_ids) + if episode_ids is None: + mask = np.ones_like(ep, dtype=bool) + else: + mask = np.isin(ep, episode_ids) if not mask.any(): continue @@ -159,45 +163,24 @@ def make_train_val_loaders( **connect_kwargs, ) -> tuple[DataLoader, DataLoader]: """ - Episode-level train/val split. - - val_fraction of episodes (randomly sampled, seeded) are held out for - validation. All timesteps within an episode go entirely to one split — - no row-level leakage between train and val. + Random window train/val split (matches le-wm Hydra config). Returns: (train_loader, val_loader) """ - db = lancedb.connect(uri, **connect_kwargs) - tbl = db.open_table(table_name) - - # Only reads one int32 column — negligible memory even at millions of rows - ep_arr = tbl.to_lance().to_table(columns=["episode_idx"])["episode_idx"].to_numpy() - all_episodes = np.unique(ep_arr) - - rng = np.random.default_rng(seed) - rng.shuffle(all_episodes) - n_val = max(1, int(len(all_episodes) * val_fraction)) - val_episodes = set(all_episodes[:n_val].tolist()) - train_episodes = set(all_episodes[n_val:].tolist()) - print(f" Split: {len(train_episodes)} train episodes, {len(val_episodes)} val episodes") - - print(" Computing column normalizers on train episodes...") + print(" Computing column normalizers (all episodes)...") normalizers = _compute_column_normalizers( uri=uri, table_name=table_name, columns=columns, - train_episodes=train_episodes, + train_episodes=None, connect_kwargs=connect_kwargs, ) - for col, stats in normalizers.items(): print(f" {col}: mean={stats['mean'].tolist()}, std={stats['std'].tolist()}") - # Build full datasets then restrict _window_starts by episode membership. - # Both datasets share the same table — no data is copied. - train_ds = LeWMLanceDataset( + base_ds = LeWMLanceDataset( uri, table_name, columns, @@ -207,7 +190,19 @@ def make_train_val_loaders( normalizers=normalizers, **connect_kwargs, ) - val_ds = LeWMLanceDataset( + + total_windows = len(base_ds) + n_val = max(1, int(total_windows * val_fraction)) + n_train = total_windows - n_val + rng = np.random.default_rng(seed) + perm = rng.permutation(total_windows) + train_idx = np.sort(perm[:n_train]) + val_idx = np.sort(perm[n_train:]) + + train_ds = base_ds + train_ds._window_starts = train_ds._window_starts[train_idx] + + val_ds = LeWMLanceDataset( uri, table_name, columns, @@ -217,12 +212,7 @@ def make_train_val_loaders( normalizers=normalizers, **connect_kwargs, ) - - train_ep_mask = np.isin(train_ds._ep[train_ds._window_starts], list(train_episodes)) - val_ep_mask = np.isin(val_ds._ep[val_ds._window_starts], list(val_episodes)) - - train_ds._window_starts = train_ds._window_starts[train_ep_mask] - val_ds._window_starts = val_ds._window_starts[val_ep_mask] + val_ds._window_starts = val_ds._window_starts[val_idx] print(f" Windows: {len(train_ds):,} train, {len(val_ds):,} val")