wendlerc/toy-wm

TLDR

A toy implementation of a diffusion-transformer-based "world model" trained on 9 hours of Pong. Shoutout to @pufferlib for their great Pong environment, which was used to create the dataset.

Training dashboard. Action labels: 0: unconditional, 1: don't move, 2: up, 3: down (for cyan).
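
For reference, the action labels above can be written down as a small enum. This is only an illustrative sketch: the integer codes come from the dashboard caption, while the class name PongAction, the member names, and the describe helper are hypothetical and not taken from the codebase.

```python
from enum import IntEnum


class PongAction(IntEnum):
    """Action codes as listed in the dashboard caption (names are hypothetical)."""

    UNCONDITIONAL = 0  # no action conditioning
    STAY = 1           # don't move
    UP = 2             # move up
    DOWN = 3           # move down


def describe(action: int) -> str:
    """Return a lowercase, human-readable name for an integer action code."""
    return PongAction(action).name.lower()


if __name__ == "__main__":
    for code in range(4):
        print(code, describe(code))
```

Passing a code outside 0..3 raises a ValueError, which is usually preferable to silently conditioning on an undefined action.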

The only optimization this codebase really uses so far is FlexAttention, but even without it you can train a Pong simulator within a reasonable budget.

The folder structure and repo are hopefully self-explanatory. If you have any questions, or find bugs or problems with the environment setup, please don't hesitate to create an issue.

Setup

  • install dependencies: uv sync

Inference / running the demo

  • download model: uv run scripts/download_model.py
  • start pong server: uv run play_pong.py

Training

  • download pong dataset: uv run scripts/download_dataset.py
  • train your own pong simulator (should take <= 30 minutes on an A6000): uv run python -m src.main
  • to use it, update configs/inference.yaml. By default, checkpoints are saved in ./experiments/wandb-run-name. To play with your model while it is still training, put the run folder into the checkpoint field, then run: uv run python play_pong.py. This should start a server running pong that you can connect to and play interactively. There is also generate_with_cache.ipynb for experimenting with inference.
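
If you are unsure which run folder to put into the checkpoint field, a small helper along these lines can find the most recently modified folder under ./experiments. This is a sketch and not part of the repository: the function name latest_run_dir is hypothetical, and it assumes each run writes into its own subfolder of ./experiments, as described above.

```python
from pathlib import Path
from typing import Optional


def latest_run_dir(experiments_root: str = "experiments") -> Optional[Path]:
    """Return the most recently modified run folder, or None if there is none.

    Assumption: every training run has its own subfolder directly under
    `experiments_root` (e.g. ./experiments/<wandb-run-name>).
    """
    root = Path(experiments_root)
    if not root.is_dir():
        return None
    run_dirs = [p for p in root.iterdir() if p.is_dir()]
    if not run_dirs:
        return None
    # Pick the folder with the newest modification time.
    return max(run_dirs, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    run = latest_run_dir()
    if run is None:
        print("no run folders found under ./experiments")
    else:
        print(f"candidate for the checkpoint field: {run}")
```

The printed path still has to be copied into configs/inference.yaml by hand; the sketch deliberately does not edit the config, since its exact layout is not described here.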

Technical details

Resources

Other repositories that I found useful along the way:
