
thanhquangtran/FW-TSW


Revisiting Tree-Sliced Wasserstein Distance Through the Lens of the Fermat-Weber Problem

Conference: ICLR 2026 | License: Apache 2.0

📄 Paper: Revisiting Tree-Sliced Wasserstein Distance Through the Lens of the Fermat-Weber Problem (ICLR 2026)

Authors: Viet-Hoang Tran*, Thanh Tran*, Thanh Chu*, Duy-Tung Pham, Trung-Khang Tran, Tam Le†, Tan Minh Nguyen†

* Equal contribution; † Co-advising

FW-TSW improves Tree-Sliced Wasserstein (TSW) distances by leveraging the Fermat-Weber (geometric median) formulation for tree construction. Instead of placing tree roots at random locations, FW-TSW centers roots at the geometric median of the data. Its variant FW-TSW* further generates line directions from data-driven paths between the two distributions. This yields:

  • Better root placement via the geometric median (Fermat-Weber point), centering trees where data actually lives.
  • Data-driven line directions (FW-TSW*) that align projections with the transport structure.
  • Drop-in replacement for standard TSW that outperforms existing methods across gradient flows, generative modeling, and runtime benchmarks in both Euclidean and spherical settings.
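The Fermat-Weber point (geometric median) of a sample is the point that minimizes the sum of Euclidean distances to the samples. A standard way to approximate it is Weiszfeld's iteratively reweighted averaging, sketched below in NumPy. This is an illustrative sketch, not the repository's solver: the helper name `geometric_median`, the stopping rule, the distance clamping, and the choice to take the median of the pooled samples of X and Y are all assumptions.

```python
import numpy as np

def geometric_median(points, n_iter=200, eps=1e-12, tol=1e-10):
    """Approximate argmin_z sum_i ||z - points[i]|| with Weiszfeld iterations.

    Hypothetical helper (not part of fw_tsw). points: array of shape (N, d).
    """
    z = points.mean(axis=0)  # start from the centroid
    for _ in range(n_iter):
        dist = np.linalg.norm(points - z, axis=1)
        # Clamp distances so an iterate landing exactly on a data point does not divide by zero.
        w = 1.0 / np.maximum(dist, eps)
        z_new = (w[:, None] * points).sum(axis=0) / w.sum()
        if np.linalg.norm(z_new - z) < tol:
            return z_new
        z = z_new
    return z

# Usage: a single far outlier drags the mean but barely moves the median.
pts = np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [100.0, 0.0]])
print(pts.mean(axis=0))        # [25.  0.]
print(geometric_median(pts))   # close to [0, 0]

# Assumption: the tree root is the median of the pooled samples of both measures.
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 64))
Y = rng.standard_normal((500, 64)) + 1.0
root = geometric_median(np.vstack([X, Y]))
```

The outlier example shows why the median is a sensible center for the trees: it stays with the bulk of the data, whereas the mean is pulled a quarter of the way toward the outlier.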

Requirements

To create the conda environment used in our experiments:

```bash
conda env create --file=environment.yaml
conda activate fw-tsw
pip install -e .
```

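The environment file installs PyTorch with CUDA 11.8 and the Python dependencies needed to run the experiments, and pip install -e . installs this repository's package (imported as fw_tsw) in editable mode.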


Quick Start

Below is a minimal example computing FW-TSW between two empirical measures in $\mathbb{R}^d$. Uncomment the FW-TSW* block to switch to data-driven line directions:

```python
import torch
from fw_tsw.tsw import TWConcurrentLines
from fw_tsw.utils import generate_trees_frames

device = "cuda" if torch.cuda.is_available() else "cpu"

ntrees = 250   # number of trees (L in the paper)
nlines = 4     # lines per tree (k in the paper)
d = 64         # data dimension

N = 500
X = torch.randn(N, d, device=device)
Y = torch.randn(N, d, device=device)

# FW-TSW: geometric median root + Gaussian line directions
theta, intercept = generate_trees_frames(
    ntrees=ntrees,
    nlines=nlines,
    d=d,
    std=0.1,
    intercept_mode="geometric_median",
    gen_mode="gaussian_raw",
    X=X, Y=Y,
    device=device,
)

# # FW-TSW*: geometric median root + data-driven line directions
# theta, intercept = generate_trees_frames(
#     ntrees=ntrees,
#     nlines=nlines,
#     d=d,
#     std=0.1,
#     intercept_mode="geometric_median",
#     gen_mode="data_driven",
#     X=X, Y=Y,
#     kappa=10.0,
#     device=device,
# )

tw_obj = TWConcurrentLines(
    ntrees=ntrees, nlines=nlines,
    p=2, delta=2,
    mass_division="distance_based",
    device=device,
)

distance = tw_obj(X, Y, theta, intercept)
print(f"FW-TSW distance: {distance.item():.6f}")
```
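To make the moving parts of that call more concrete, here is a self-contained NumPy sketch of a tree-sliced distance on concurrent lines. It is written under explicit assumptions and is not the fw_tsw implementation: each tree is k unit directions through one shared root; a point's coordinate on line i is <x - root, theta_i>; the distance_based split is assumed to be a softmax of -delta times the point's distance to each line; and the per-tree cost is the p = 1 tree Wasserstein distance (sum over edges of edge length times the absolute difference of the masses lying beyond the edge), whereas the Quick Start uses p=2. The helper names `tree_w1` and `tsw1` and the pooled-mean default root are hypothetical.

```python
import numpy as np

def _ray_cost(pos, signed_mass):
    """Integral over s >= 0 of |signed mass located beyond s| for masses at positions pos > 0."""
    if pos.size == 0:
        return 0.0
    order = np.argsort(pos)
    pos, signed_mass = pos[order], signed_mass[order]
    beyond = np.cumsum(signed_mass[::-1])[::-1]    # signed mass at coordinate >= pos[j]
    gaps = np.diff(np.concatenate(([0.0], pos)))   # lengths of the edges between breakpoints
    return float(np.sum(gaps * np.abs(beyond)))

def tree_w1(X, Y, root, theta, delta=2.0):
    """p=1 tree Wasserstein between uniform empirical measures on one tree of concurrent lines."""
    Z = np.vstack([X, Y])
    sign = np.concatenate([np.full(len(X), 1.0 / len(X)), np.full(len(Y), -1.0 / len(Y))])
    shifted = Z - root                                   # (N+M, d)
    coord = shifted @ theta.T                            # (N+M, k) coordinate on each line
    resid = shifted[:, None, :] - coord[:, :, None] * theta[None, :, :]
    dist = np.linalg.norm(resid, axis=2)                 # distance from each point to each line
    logits = -delta * dist                               # assumption: nearer lines get more mass
    logits -= logits.max(axis=1, keepdims=True)          # numerically stable softmax
    alpha = np.exp(logits)
    alpha /= alpha.sum(axis=1, keepdims=True)
    mass = alpha * sign[:, None]                         # +mass for X, -mass for Y, per line
    total = 0.0
    for i in range(theta.shape[0]):
        t, m = coord[:, i], mass[:, i]
        total += _ray_cost(t[t > 0], m[t > 0])           # ray on the positive side of the root
        total += _ray_cost(-t[t < 0], m[t < 0])          # ray on the negative side of the root
    return total

def tsw1(X, Y, ntrees=250, nlines=4, delta=2.0, root=None, seed=0):
    """Monte Carlo average of tree_w1 over trees with Gaussian directions (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    if root is None:
        root = np.vstack([X, Y]).mean(axis=0)           # stand-in; FW-TSW uses the geometric median
    vals = []
    for _ in range(ntrees):
        theta = rng.standard_normal((nlines, X.shape[1]))
        theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit line directions
        vals.append(tree_w1(X, Y, root, theta, delta))
    return float(np.mean(vals))

# Sanity check: with one line in 1D, the cost is the usual W1 = |3 - 1| = 2.
print(tree_w1(np.array([[1.0]]), np.array([[3.0]]), np.zeros(1), np.array([[1.0]])))  # 2.0
```

Passing a geometric-median root through the root argument moves the sketch closer to the FW-TSW setting; the p = 2 cost and the data-driven directions of FW-TSW* are not reproduced here.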

Experiments

Gradient flow (experiments/gradient-flow/)

Particle transport toward Gaussian targets on $\mathbb{R}^d$, comparing SW, TSW, and FW-TSW variants (root-path, root-only, orthogonal, distance-based, uniform).
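As a rough illustration of what these gradient-flow experiments do, the sketch below moves particles toward a Gaussian target by gradient descent on a Monte Carlo sliced Wasserstein loss. It uses plain SW (one of the baselines listed above), not FW-TSW, and computes the gradient in closed form with NumPy rather than with autograd; the step size, projection count, and helper names are assumptions chosen for the example.

```python
import numpy as np

def sw2_loss_and_grad(X, Y, n_proj, rng):
    """Monte Carlo SW_2^2 between equal-size clouds and its gradient w.r.t. X."""
    n, d = X.shape
    theta = rng.standard_normal((n_proj, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)   # random unit directions
    px, py = X @ theta.T, Y @ theta.T                        # (n, n_proj) 1D projections
    rank = np.argsort(px, axis=0)                            # sorting solves each 1D OT problem
    matched = np.empty_like(px)
    # The j-th smallest projection of X is matched to the j-th smallest projection of Y.
    np.put_along_axis(matched, rank, np.sort(py, axis=0), axis=0)
    diff = px - matched
    loss = np.mean(diff ** 2)
    grad = 2.0 * (diff @ theta) / (n * n_proj)               # d loss / d X, shape (n, d)
    return loss, grad

def gradient_flow(X0, Y, steps=200, lr=0.5, n_proj=50, seed=0):
    rng = np.random.default_rng(seed)
    X = X0.copy()
    history = []
    for _ in range(steps):
        loss, grad = sw2_loss_and_grad(X, Y, n_proj, rng)
        X -= lr * len(X) * grad          # rescale by n so each particle takes an O(1) step
        history.append(loss)
    return X, history

rng = np.random.default_rng(0)
Y = rng.standard_normal((200, 2))          # samples from the Gaussian target
X0 = rng.standard_normal((200, 2)) + 5.0   # particles start far from the target
X, history = gradient_flow(X0, Y)
print(f"SW loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

Swapping the loss for a tree-sliced distance changes only how the per-particle gradient is obtained; the outer update loop stays the same.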

Denoising Diffusion GAN (experiments/denoising-diffusion-gan/)

FW-TSW as a discriminator loss for CIFAR-10 generation using denoising diffusion GANs.

Gradient flow on images (experiments/gradient-flow-images/)

Gradient flow on synthetic 28x28 digit images (appendix experiment).

Point cloud gradient flow (experiments/point-cloud-gf/)

Gradient flow on 3D point clouds (appendix experiment).

Topic modelling (experiments/topic-model/)

FW-TSW* as a latent-space loss for neural topic models.

Runtime benchmarks (experiments/runtime-plot/)

Runtime and memory comparisons between FW-TSW variants and baselines across dimensions and sample sizes.
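For a quick scaling check of your own, a minimal wall-clock timing harness can look like the sketch below. It times a plain sliced Wasserstein-1 estimate (a stand-in for any of the compared distances) with time.perf_counter over a grid of sample sizes and dimensions; the grid, projection count, and helper names are assumptions, and it measures time only, not memory.

```python
import time
import numpy as np

def median_runtime(fn, *args, repeats=5):
    """Median wall-clock seconds of fn(*args) over several repeats."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    return float(np.median(times))

def sliced_w1(X, Y, theta):
    """Sliced W1 between equal-size clouds: compare sorted projections on each direction."""
    px = np.sort(X @ theta.T, axis=0)
    py = np.sort(Y @ theta.T, axis=0)
    return float(np.mean(np.abs(px - py)))

def benchmark(sizes=(256, 1024, 4096), dims=(8, 64), n_proj=200, seed=0):
    rng = np.random.default_rng(seed)
    results = {}
    for d in dims:
        theta = rng.standard_normal((n_proj, d))
        theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit projection directions
        for n in sizes:
            X, Y = rng.standard_normal((n, d)), rng.standard_normal((n, d))
            results[(n, d)] = median_runtime(sliced_w1, X, Y, theta)
    return results

for (n, d), sec in benchmark().items():
    print(f"n={n:5d}  d={d:3d}  {1e3 * sec:8.2f} ms")
```

Reporting the median of several repeats reduces the influence of one-off warm-up or scheduling delays; on a GPU one would also need to synchronize the device before reading the timer.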

Each subdirectory has its own README or inline comments describing configuration and how to run the corresponding experiments.


Acknowledgments

This code builds on prior work in Tree-Sliced and Sliced Optimal Transport, including:


Citation

If you find this repository useful, please cite:

```bibtex
@inproceedings{tran2026revisiting,
    title={Revisiting Tree-Sliced Wasserstein Distance Through the Lens of the Fermat{\textendash}Weber Problem},
    author={Viet-Hoang Tran and Thanh Tran and Thanh Chu and Duy-Tung Pham and Trung-Khang Tran and Tam Le and Tan Minh Nguyen},
    booktitle={The Fourteenth International Conference on Learning Representations},
    year={2026},
    url={https://openreview.net/forum?id=kDqG03v05B}
}
```

License

This project is licensed under the Apache License 2.0; see the LICENSE file for details.
