HistoVAE


Fast-converging convolutional Variational Autoencoder (VAE) for whole-slide image (WSI) .tif/.svs files.

This repo trains directly on random WSI tiles via OpenSlide and is designed to converge quickly on histology tile distributions.

On a single RTX 5090 with default settings, training reaches accurate half-resolution reconstructions within about 90 seconds. After roughly 15 minutes, reconstructed tiles are distinguishable from the originals only by differences in noise patterns. This makes the model suitable for real-time, human-in-the-loop workflows.

What’s in this repo

Model/training highlights

Implemented in vae.py:

  • Convolutional VAE with spatial latents (not flattened)
  • Cyclic KL annealing to reduce posterior collapse
  • Mixed precision (AMP) support
  • TensorBoard logging (loss curves + image reconstructions)
  • OpenSlide-backed dataset that samples random tiles and filters empty/background tiles
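To illustrate the cyclic KL annealing idea (the actual schedule lives in vae.py; the function and parameter names below are hypothetical), the KL weight ramps from 0 to a maximum within each cycle and then restarts, which keeps the decoder from learning to ignore the latent code:

```python
def cyclic_kl_weight(step, cycle_steps=8000, beta_max=0.3, ramp_fraction=0.5):
    """Cyclic KL annealing sketch: within each cycle, ramp the KL weight
    linearly from 0 to beta_max, then hold it flat until the cycle ends.
    Restarting the ramp each cycle helps reduce posterior collapse."""
    pos = (step % cycle_steps) / cycle_steps  # position in [0, 1) within the cycle
    if pos < ramp_fraction:
        return beta_max * (pos / ramp_fraction)  # linear ramp phase
    return beta_max  # hold phase
```

The returned weight multiplies the KL term of the ELBO, i.e. loss = reconstruction + beta_t * KL.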

Quickstart

Run with Docker (GPU)

This repo includes a GPU-capable Docker image (Ubuntu base + CUDA-enabled PyTorch installed via pip). To use the GPU, you’ll need:

  • NVIDIA drivers installed on the host
  • Docker + NVIDIA Container Toolkit (so --gpus all works)

The GitHub Docker publish workflow also pushes a prebuilt image to sablecrestlabs/histovae:latest, and tagged releases additionally publish sablecrestlabs/histovae:<tag>.

Pull

docker pull sablecrestlabs/histovae:latest

Build

docker build -t histovae .

Defaults are set in the Dockerfile (PYTORCH_VERSION=2.10.0, CUDA_VERSION=13.0). You can override them:

docker build -t histovae \
  --build-arg PYTORCH_VERSION=2.10.0 \
  --build-arg CUDA_VERSION=13.0 \
  .

The image also sets CUDA_VERSION as an environment variable inside the container.

Train (mount host data directory)

Mount your WSI directory from the host into /data in the container:

docker run --rm --gpus all \
  -v /host/path/to/wsi_files:/data:ro \
  -v "$PWD/runs_vae:/workspace/runs_vae" \
  -v "$PWD/checkpoints_vae:/workspace/checkpoints_vae" \
  histovae \
  --data-root /data \
  --device cuda

If you want a shell instead of running training, override the entrypoint:

docker run --rm -it --gpus all \
  -v /host/path/to/wsi:/data:ro \
  --entrypoint bash \
  histovae

Run on bare metal

Requirements

  • Python 3.9+
  • A working OpenSlide install (system library) + openslide-python

On Ubuntu/Debian, you typically need:

sudo apt-get update
sudo apt-get install -y libopenslide0

On macOS (Homebrew):

brew install openslide

Install

python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

Train

Point --data-root at a directory containing WSI .tif / .svs files (case-insensitive, searched recursively):

python vae.py --data-root /path/to/wsi_files

Common knobs:

python vae.py \
  --data-root /path/to/wsi_files \
  --img-size 256 \
  --batch-size 8 \
  --tiles-per-epoch 10000 \
  --level 0 \
  --epochs 50 \
  --beta 0.3 \
  --kl-warmup-steps 8000

By default, tiles are normalized from [0, 1] to [-1, 1] before being fed to the model.
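That normalization is just a linear map; a minimal sketch (the real pipeline presumably applies it per channel inside the dataset transforms):

```python
def normalize(value):
    """Map a pixel intensity from [0, 1] to [-1, 1]."""
    return value * 2.0 - 1.0

def denormalize(value):
    """Invert the map, e.g. before saving reconstructions as images."""
    return (value + 1.0) / 2.0
```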

Supported Formats

HistoVAE relies on OpenSlide for slide access, so it can open exactly the formats that the OpenSlide build on the host supports. These typically include:

  • .svs
  • .tif
  • .dcm
  • .ndpi
  • .vms
  • .vmu
  • .scn
  • .mrxs
  • .tiff
  • .svslide
  • .bif
  • .czi

Monitor with TensorBoard

Training logs go under runs_vae/<timestamp>/ by default.

If you have Docker, you can run:

# note: arguments are optional
./tensorboard.sh runs_vae 6006

Then open http://localhost:6006.

Data format

vae.py uses OpenSlideTileDataset, which:

  • Recursively scans --data-root for .tif and .svs files (case-insensitive)
  • Randomly samples tile coordinates at a chosen OpenSlide pyramid --level
  • Converts OpenSlide RGBA output to RGB on a white background
  • Filters near-empty tiles (very low variance / mostly black / mostly white)
  • Applies simple augmentations (random flips, rotations, optional light color jitter)
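The empty-tile filter above can be sketched with simple intensity statistics (the thresholds and function name here are illustrative, not the values used in vae.py):

```python
def is_informative_tile(gray, min_std=0.02, dark=0.05, bright=0.95, max_frac=0.9):
    """Heuristic background filter on grayscale intensities in [0, 1]:
    reject tiles that are nearly constant, mostly white, or mostly black."""
    n = len(gray)
    mean = sum(gray) / n
    std = (sum((g - mean) ** 2 for g in gray) / n) ** 0.5
    if std < min_std:
        return False  # near-constant tile (e.g. blank glass)
    if sum(g > bright for g in gray) / n > max_frac:
        return False  # mostly white background
    if sum(g < dark for g in gray) / n > max_frac:
        return False  # mostly black (scanner border)
    return True
```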

If you already have tiles extracted as PNG/JPEG, you'll need to swap OpenSlideTileDataset for a standard image-folder dataset.
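A minimal sketch of the file-scanning half of such a swap (the class and attribute names are hypothetical; a full Dataset would additionally decode each path with PIL and apply the same normalization transform):

```python
from pathlib import Path

class ImageTilePaths:
    """Recursively collect pre-extracted tile images under a root directory."""

    EXTENSIONS = {".png", ".jpg", ".jpeg"}

    def __init__(self, root):
        # Case-insensitive extension match, mirroring the WSI scanner's behavior.
        self.paths = sorted(
            p for p in Path(root).rglob("*")
            if p.suffix.lower() in self.EXTENSIONS
        )

    def __len__(self):
        return len(self.paths)
```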

Outputs

  • Checkpoints (default --checkpoint-dir checkpoints_vae):
    • checkpoint_epoch_<N>.pt (periodic)
    • checkpoint_best.pt (best validation loss)
    • checkpoint_final.pt
  • TensorBoard logs (default --log-dir runs_vae):
    • Scalar losses (train/val)
    • Image grids of original vs reconstruction

Loading a checkpoint (example)

Checkpoints saved by training are dictionaries containing at least a model_state_dict entry.

import torch

from vae import VAE, VAEConfig

ckpt = torch.load("checkpoints_vae/checkpoint_best.pt", map_location="cpu")

# Training saves a small config subset in ckpt["config"].
cfg = ckpt.get("config", {})
config = VAEConfig(
    img_channels=cfg.get("img_channels", 3),
    img_size=cfg.get("img_size", 256),
    base_channels=cfg.get("base_channels", 32),
    channel_multipliers=tuple(cfg.get("channel_multipliers", (1, 2, 4))),
    latent_channels=cfg.get("latent_channels", 32),
)

model = VAE(config=config)
model.load_state_dict(ckpt["model_state_dict"], strict=True)
model.eval()

License

Dual-licensed under MIT and Apache 2.0.