TorchEBM

⚑ A PyTorch library for energy-based modeling, with support for flow and diffusion methods.

EBM Training Animation

What is βˆ‡ TorchEBM πŸ“?

Energy-based models define probability distributions through a scalar energy function: lower energy means higher probability. This formulation is general enough that many generative approaches, from MCMC sampling to score matching to flow-based generation, can be understood through the same lens.
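Concretely, an energy function E induces an unnormalized density p(x) ∝ exp(−E(x)). A minimal pure-Python sketch of this relationship (an illustration only, independent of TorchEBM):

```python
import math

def energy(x):
    # Simple quadratic energy: the minimum (lowest energy) is at x = 0.
    return 0.5 * x ** 2

def unnormalized_density(x):
    # Boltzmann form: p(x) is proportional to exp(-E(x)).
    return math.exp(-energy(x))

# Lower energy corresponds to higher (unnormalized) probability.
assert energy(0.0) < energy(2.0)
assert unnormalized_density(0.0) > unnormalized_density(2.0)
```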

TorchEBM is a PyTorch library that gives you composable tools for this entire spectrum. You can define energy landscapes, train models with various learning objectives, and sample via MCMC, optimization, or learned continuous-time dynamics (ODEs/SDEs). The library handles classical EBM training (contrastive divergence, score matching) as well as modern interpolant-based and equilibrium-based generation methods.
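As a sketch of the MCMC side, one unadjusted Langevin update moves a sample downhill on the energy and adds Gaussian noise. The snippet below is a self-contained pure-Python illustration for a standard Gaussian target; it is not TorchEBM's `LangevinDynamics` implementation:

```python
import math
import random

def grad_energy(x):
    # Gradient of E(x) = 0.5 * x**2, i.e. a standard Gaussian target.
    return x

def langevin_step(x, step_size):
    # One unadjusted Langevin update:
    #   x <- x - step_size * grad_E(x) + sqrt(2 * step_size) * noise
    noise = random.gauss(0.0, 1.0)
    return x - step_size * grad_energy(x) + math.sqrt(2.0 * step_size) * noise

random.seed(0)
x = 5.0  # start far from the mode at 0
samples = []
for _ in range(3000):
    x = langevin_step(x, step_size=0.01)
    samples.append(x)

# After burn-in, the chain hovers around the high-probability region near 0.
mean_tail = sum(samples[1000:]) / len(samples[1000:])
```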

πŸ“š For the full documentation, please visit the official website of TorchEBM πŸ“.

Features

  • Energy models with built-in analytical potentials and support for custom neural network energy functions
  • MCMC and optimization-based samplers for drawing samples from energy landscapes
  • Flow and diffusion samplers that generate via ODE/SDE integration of learned velocity or score fields
  • Training objectives including contrastive divergence variants, score matching variants, and equilibrium matching
  • Interpolation schemes for specifying noise-to-data paths in flow and diffusion models
  • Numerical integrators for SDE, ODE, and Hamiltonian dynamics
  • Neural network architectures ready for conditional generation
  • Synthetic datasets for rapid prototyping and benchmarking
  • Hyperparameter schedulers for step sizes, noise scales, and other training parameters
  • CUDA acceleration and mixed precision support
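To illustrate the integrator piece of the list above, here is a self-contained leapfrog sketch for Hamiltonian dynamics (plain Python, not the library's API). Leapfrog is symplectic, so it approximately conserves the Hamiltonian, which is what makes HMC proposals efficient:

```python
def leapfrog(x, p, grad_energy, step_size, n_steps):
    # Leapfrog integration of Hamiltonian dynamics with H(x, p) = E(x) + p**2 / 2.
    p = p - 0.5 * step_size * grad_energy(x)   # initial half step in momentum
    for _ in range(n_steps - 1):
        x = x + step_size * p                  # full step in position
        p = p - step_size * grad_energy(x)     # full step in momentum
    x = x + step_size * p                      # last full step in position
    p = p - 0.5 * step_size * grad_energy(x)   # final half step in momentum
    return x, p

# Harmonic oscillator: E(x) = 0.5 * x**2, so grad E(x) = x.
x, p = leapfrog(1.0, 0.0, lambda x: x, step_size=0.1, n_steps=100)
energy_before = 0.5 * 1.0 ** 2
energy_after = 0.5 * x ** 2 + 0.5 * p ** 2
# The Hamiltonian stays nearly constant despite 100 integration steps.
```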

8 Gaussians Flow

Gaussian Double Well Rastrigin Rosenbrock
Gaussian Mixture Two Moons Swiss Roll Checkerboard

Installation

```bash
pip install torchebm
```

Dependencies

TorchEBM builds on PyTorch; see the project documentation for the full dependency list.

Usage Examples

MCMC Sampling

```python
import torch
from torchebm.core import GaussianModel
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GaussianModel(mean=torch.zeros(2), cov=torch.eye(2), device=device)

sampler = LangevinDynamics(model=model, step_size=0.01, device=device)
samples = sampler.sample(x=torch.randn(500, 2, device=device), n_steps=100)
print(samples.shape)  # torch.Size([500, 2])
```

Training with Contrastive Divergence

```python
import torch
from torch.utils.data import DataLoader

from torchebm.core import BaseModel
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset

class MLPEnergy(BaseModel):
    """Small MLP mapping a point to a scalar energy."""
    def __init__(self, dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 64), torch.nn.SiLU(),
            torch.nn.Linear(64, 64), torch.nn.SiLU(),
            torch.nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLPEnergy(dim=2).to(device)
sampler = LangevinDynamics(model=model, step_size=0.01, device=device)
cd_loss = ContrastiveDivergence(model=model, sampler=sampler, k_steps=10)

data = GaussianMixtureDataset(n_samples=1000, n_components=4).get_data()
loader = DataLoader(data, batch_size=64, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for batch in loader:
        optimizer.zero_grad()
        loss, _ = cd_loss(batch.to(device))
        loss.backward()
        optimizer.step()
```
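For intuition, contrastive divergence lowers the energy of real data (the positive phase) and raises the energy of model samples (the negative phase). A toy, pure-Python illustration of that surrogate loss, independent of TorchEBM's actual `ContrastiveDivergence` class:

```python
def cd_surrogate_loss(energy, data_batch, sample_batch):
    # Mean energy on data minus mean energy on model samples:
    # minimizing this pushes data energy down and sample energy up.
    pos = sum(energy(x) for x in data_batch) / len(data_batch)
    neg = sum(energy(x) for x in sample_batch) / len(sample_batch)
    return pos - neg

def toy_energy(x):
    # Quadratic energy with its minimum at x = 0.
    return 0.5 * x ** 2

# If the data already sits at the energy minimum and the model samples do
# not, the surrogate loss is negative: there is little left to correct.
loss = cd_surrogate_loss(toy_energy, data_batch=[0.0, 0.1], sample_batch=[2.0, -2.0])
assert loss < 0.0
```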

Hamiltonian Monte Carlo

```python
import torch
from torchebm.core import GaussianModel
from torchebm.samplers import HamiltonianMonteCarlo

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GaussianModel(mean=torch.zeros(10), cov=torch.eye(10), device=device)

hmc = HamiltonianMonteCarlo(model=model, step_size=0.1, n_leapfrog_steps=10, device=device)
samples = hmc.sample(dim=10, n_steps=500, n_samples=1000)
print(samples.shape)  # torch.Size([1000, 10])
```

Library Structure

```text
torchebm/
β”œβ”€β”€ core/           # Base classes, energy models, schedulers, device management
β”œβ”€β”€ samplers/       # MCMC, optimization, and flow/diffusion samplers
β”œβ”€β”€ losses/         # Training objectives (CD, score matching, equilibrium matching)
β”œβ”€β”€ interpolants/   # Noise-to-data interpolation schemes
β”œβ”€β”€ integrators/    # Numerical integrators for SDE/ODE/Hamiltonian dynamics
β”œβ”€β”€ models/         # Neural network architectures
β”œβ”€β”€ datasets/       # Synthetic data generators
β”œβ”€β”€ utils/          # Visualization and training utilities
└── cuda/           # CUDA-accelerated implementations
```

Visualization Examples

Langevin Dynamics Sampling Langevin Dynamics Trajectory Parallel Sampling

Flow Comparison
Equilibrium Matching: Linear, VP, and Cosine interpolants transforming noise into data.

Check out the examples/ directory for sample scripts.

Contributing

Contributions are welcome! Step-by-step instructions for contributing to the project can be found on the contributing.md page on the website.

Please check the issues page for current tasks or create a new issue to discuss proposed changes.

Show your Support for βˆ‡ TorchEBM πŸ“

If βˆ‡ TorchEBM πŸ“ helped you, please ⭐️ the repository and spread the word.

Thank you! πŸš€

Citation

If TorchEBM is useful in your research, please cite it:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

Changelog

See CHANGELOG for version history.

License

MIT License. See LICENSE for details.

Research Collaboration

If you are interested in collaborating on research around energy-based, flow-based, or diffusion models, feel free to reach out. Contributions to TorchEBM πŸ“ and discussions that push the field forward are always welcome.