📄 Paper: Revisiting Tree-Sliced Wasserstein Distance Through the Lens of the Fermat-Weber Problem (ICLR 2026)
Authors: Viet-Hoang Tran*, Thanh Tran*, Thanh Chu*, Duy-Tung Pham, Trung-Khang Tran, Tam Le†, Tan Minh Nguyen†
* Equal contribution † Co-advising
FW-TSW improves Tree-Sliced Wasserstein (TSW) distances by leveraging the Fermat-Weber (geometric median) formulation for tree construction. Instead of placing tree roots at random locations, FW-TSW centers roots at the geometric median of the data. Its variant FW-TSW* further generates line directions from data-driven paths between the two distributions. This yields:
- Better root placement via the geometric median (Fermat-Weber point), centering trees where data actually lives.
- Data-driven line directions (FW-TSW*) that align projections with the transport structure.
- Drop-in replacement for standard TSW that outperforms existing methods across gradient flows, generative modeling, and runtime benchmarks in both Euclidean and spherical settings.
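The sketch below illustrates the root-placement idea. It is not the repository's implementation. It approximates the Fermat-Weber point, argmin_y Σ_i ‖x_i − y‖, with Weiszfeld's classical fixed-point iteration. It then builds one tree of k concurrent lines through that root, using normalized Gaussian directions as in the plain FW-TSW setting. Several pieces are assumptions made for illustration:

- the helper names (geometric_median, sample_line_directions, project)
- pooling the samples of both measures before taking the median
- the stopping rule

The actual generate_trees_frames routine may differ, and the data-driven directions of FW-TSW* are not reproduced here.

```python
import numpy as np


def geometric_median(points, n_iter=200, eps=1e-12, tol=1e-10):
    """Approximate argmin_y sum_i ||x_i - y||_2 with Weiszfeld's iteration.

    Hypothetical helper for illustration; not part of the fw_tsw package.
    """
    y = points.mean(axis=0)  # start from the centroid
    for _ in range(n_iter):
        dist = np.linalg.norm(points - y, axis=1)
        # Clamp distances so an iterate that lands on a sample does not divide by zero.
        w = 1.0 / np.maximum(dist, eps)
        y_next = (w[:, None] * points).sum(axis=0) / w.sum()
        if np.linalg.norm(y_next - y) < tol:
            return y_next
        y = y_next
    return y


def sample_line_directions(nlines, d, rng):
    """Hypothetical: k unit directions drawn as normalized Gaussian vectors."""
    theta = rng.standard_normal((nlines, d))
    return theta / np.linalg.norm(theta, axis=1, keepdims=True)


def project(points, theta, root):
    """Coordinate of each point along each line through root: <x - root, theta_j>."""
    return (points - root) @ theta.T  # shape (N, nlines)


rng = np.random.default_rng(0)
X = rng.standard_normal((500, 64))
Y = rng.standard_normal((500, 64)) + 1.0
# Assumption: the root is the geometric median of the pooled samples of both measures.
root = geometric_median(np.concatenate([X, Y], axis=0))
theta = sample_line_directions(nlines=4, d=64, rng=rng)
coords_X = project(X, theta, root)
print(coords_X.shape)  # (500, 4)
```

The geometric median has a breakdown point of 1/2. A few distant outliers therefore barely move the root, whereas they can drag the centroid arbitrarily far.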
To create the conda environment used in our experiments:
conda env create --file=environment.yaml
conda activate fw-tsw
pip install -e .

This installs PyTorch with CUDA 11.8 and all Python dependencies needed to run the experiments.
Below is a minimal example computing FW-TSW between two empirical measures in ℝ^d (here d = 64):
import torch
from fw_tsw.tsw import TWConcurrentLines
from fw_tsw.utils import generate_trees_frames

device = "cuda" if torch.cuda.is_available() else "cpu"

ntrees = 250  # number of trees (L in the paper)
nlines = 4    # lines per tree (k in the paper)
d = 64        # data dimension
N = 500       # number of samples per measure

X = torch.randn(N, d, device=device)
Y = torch.randn(N, d, device=device)

# FW-TSW: geometric median root + Gaussian line directions
theta, intercept = generate_trees_frames(
    ntrees=ntrees,
    nlines=nlines,
    d=d,
    std=0.1,
    intercept_mode="geometric_median",
    gen_mode="gaussian_raw",
    X=X, Y=Y,
    device=device,
)

# FW-TSW*: geometric median root + data-driven line directions
# (uncomment the call below to use it instead of the FW-TSW frames above)
# theta, intercept = generate_trees_frames(
#     ntrees=ntrees,
#     nlines=nlines,
#     d=d,
#     std=0.1,
#     intercept_mode="geometric_median",
#     gen_mode="data_driven",
#     X=X, Y=Y,
#     kappa=10.0,
#     device=device,
# )

tw_obj = TWConcurrentLines(
    ntrees=ntrees,
    nlines=nlines,
    p=2,
    delta=2,
    mass_division="distance_based",
    device=device,
)

distance = tw_obj(X, Y, theta, intercept)
print(f"FW-TSW distance: {distance.item():.6f}")

The experiments are organized into subdirectories:
- Particle transport toward Gaussian targets on
- FW-TSW as a discriminator loss for CIFAR-10 generation using denoising diffusion GANs.
- Gradient flow on synthetic 28x28 digit images (appendix experiment).
- Gradient flow on 3D point clouds (appendix experiment).
- FW-TSW* as a latent-space loss for neural topic models.
- Runtime and memory comparisons between FW-TSW variants and baselines across dimensions and sample sizes.
Each subdirectory has its own README or inline comments describing configuration and how to run the corresponding experiments.
This code builds on prior work in Tree-Sliced and Sliced Optimal Transport, including:
If you find this repository useful, please cite:
@inproceedings{tran2026revisiting,
title={Revisiting Tree-Sliced Wasserstein Distance Through the Lens of the Fermat{\textendash}Weber Problem},
author={Viet-Hoang Tran and Thanh Tran and Thanh Chu and Duy-Tung Pham and Trung-Khang Tran and Tam Le and Tan Minh Nguyen},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=kDqG03v05B}
}

This project is licensed under the Apache License 2.0; see the LICENSE file for details.