
wendlerc/toy-wm


TLDR

A toy implementation of a diffusion-transformer-based "world model". Supports both Pong (pixel space, 9 hours of gameplay) and Doom (latent space via DC-AE, PvP deathmatch). Shoutout to @pufferlib for their great Pong environment, which was used for dataset creation.

Training dashboard. Action conditioning: 0: unconditional, 1: don't move, 2: up, 3: down (for cyan).
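The "unconditional" class (action 0) hints at classifier-free guidance over actions. Purely as an illustrative sketch (the guidance scheme and numbers here are assumptions, not something this repo confirms), the standard CFG combination looks like:

```python
def cfg_combine(uncond, cond, scale):
    """Classifier-free guidance: push the conditional prediction
    away from the unconditional one by `scale`."""
    return [u + scale * (c - u) for u, c in zip(uncond, cond)]

# toy per-element model outputs (hypothetical numbers)
uncond = [0.0, 0.5]
cond = [1.0, 0.5]
print(cfg_combine(uncond, cond, 2.0))  # -> [2.0, 0.5]
```

With scale 1.0 this reduces to the plain conditional prediction; scale > 1 sharpens the action conditioning at some cost in sample diversity.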

The only optimization this codebase really uses so far is FlexAttention, but even without it you can train a Pong simulator on a reasonable budget.
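FlexAttention works by compiling a user-supplied mask predicate. As a torch-free sketch of the kind of frame-causal predicate a frame-autoregressive world model might use (the function name, signature style, and tokens-per-frame figure are assumptions, not the repo's actual code):

```python
TOKENS_PER_FRAME = 64  # hypothetical number of spatial tokens per frame

def frame_causal(b, h, q_idx, kv_idx):
    """FlexAttention-style mask_mod predicate: a query token may attend
    to all tokens in its own frame and in earlier frames, but not to
    tokens from future frames."""
    return kv_idx // TOKENS_PER_FRAME <= q_idx // TOKENS_PER_FRAME

# token 70 lives in frame 1: it can see frame 0 (token 10)
# but not frame 2 (token 130)
print(frame_causal(0, 0, 70, 10), frame_causal(0, 0, 70, 130))  # True False
```

In real FlexAttention code such a predicate would be passed to `create_block_mask` and then to `flex_attention`, which compiles it into a fused kernel instead of materializing a dense mask.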

The folder structure and repo are hopefully self-explanatory. If you have any questions, or find bugs or problems with the environment setup, please don't hesitate to open an issue.

Setup

  • install dependencies: uv sync

Inference / running the demo

Pong

  • download the model: uv run scripts/download_model.py
  • start pong server: uv run python play_pong.py --checkpoint experiments/<run-name>/model.pt
  • server runs at http://localhost:5000

Doom

Doom world model demo

  • train a model first (see Training below), or use a checkpoint
  • start doom server: uv run python play_doom.py --checkpoint experiments/<run-name>/model.pt
  • server runs at http://localhost:4444
  • controls: WASD + mouse look + click to shoot (click frame to capture mouse, Esc to release)

Training

Pong

  • download the Pong dataset: uv run scripts/download_dataset.py
  • train a Pong simulator (should take <= 30 minutes on an A6000): uv run python -m src.main --config configs/pong.yaml. The first few training steps are slow due to torch.compile graph capture.
  • checkpoints are saved to ./experiments/<wandb-run-name>/. To play with your model: uv run python play_pong.py --checkpoint experiments/<run-name>/model.pt. You can also pass a directory to auto-load the best checkpoint.

Doom (latent diffusion)

The Doom world model operates in latent space using DC-AE (32x spatial compression, 32 channels). The dataset contains pre-encoded latent frames from Doom PvP deathmatch gameplay.
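To get a feel for the token count, here is the shape arithmetic for a hypothetical 256×256 RGB frame (the actual Doom resolution is not stated here; the 32x spatial / 32-channel figures come from the DC-AE setup described above):

```python
# DC-AE f32c32: 32x spatial downsampling, 32 latent channels
h = w = 256            # hypothetical frame resolution (assumption)
down, channels = 32, 32

lh, lw = h // down, w // down        # 8 x 8 latent grid
tokens_per_frame = lh * lw           # 64 spatial tokens per frame
latent_shape = (channels, lh, lw)

print(latent_shape, tokens_per_frame)  # (32, 8, 8) 64
```

Working in this latent space means the diffusion transformer sees a few dozen tokens per frame instead of tens of thousands of pixels, which is what makes Doom tractable on a single GPU.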

1. Download dataset

# 5 shards / ~20 GB (default, enough for training):
uv run scripts/download_doom_dataset.py
# or to a custom location:
uv run scripts/download_doom_dataset.py --output /tmp/doom_latents
# minimal test (1 shard / ~4 GB):
uv run scripts/download_doom_dataset.py --n-shards 1
# everything (~770 GB):
uv run scripts/download_doom_dataset.py --all

If you download to a custom location, pass --shard-dir to the training and play scripts.
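Rough shard arithmetic implied by the sizes quoted above (one shard is ~4 GB):

```python
gb_per_shard = 4        # ~4 GB per shard (from the 1-shard download size)
default_shards = 5      # default download
total_gb = 770          # full dataset

print(default_shards * gb_per_shard)   # ~20 GB for the default download
print(total_gb // gb_per_shard)        # roughly 192 shards in the full set
```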

2. Train

uv run python -m src.main --config configs/doom.yaml
# or with a custom shard location:
uv run python -m src.main --config configs/doom.yaml --shard-dir /tmp/doom_latents

The default config trains for 5000 steps with batch_size=64 in bf16. On an A6000 (48 GB), this takes ~2.5 hours at ~1.65s/step. Loss goes from ~5.4 to ~1.3. The first step is slow (~4 min) due to torch.compile graph capture.
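Sanity-checking the wall-clock figure from the numbers above:

```python
steps, sec_per_step = 5000, 1.65
hours = steps * sec_per_step / 3600
print(f"{hours:.2f} h")  # ~2.29 h of pure stepping
# the quoted ~2.5 h also absorbs torch.compile, warmup, and checkpointing overhead
```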

Checkpoints are saved every 500 steps to experiments/<wandb-run-name>/. The final model is saved as model.pt in the same directory.

3. Play

uv run python play_doom.py --checkpoint experiments/<run-name>/model.pt
# or with a custom shard location (used for start frame loading):
uv run python play_doom.py --checkpoint experiments/<run-name>/model.pt --shard-dir /tmp/doom_latents

The server starts at http://localhost:4444. Startup takes several minutes (model + VAE compilation, warmup, start frame loading). Controls: WASD to move, mouse to look, click to shoot (click the frame to capture the mouse, Esc to release). The DC-AE VAE decoder (~1 GB) is downloaded automatically on first use.

If running on a remote server, use SSH port forwarding: ssh -L 4444:localhost:4444 <host>

Technical details

Resources

Other repositories that I found useful along the way:
