Skip to content

Commit b8a0af7

Browse files
committed
Found bug in wasserstein
1 parent e39c19b commit b8a0af7

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/test_distances.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ def test_inf_deathtime(self):
180180
with pytest.warns(UserWarning, match="dgm2 has points with non-finite death") as w:
181181
dist2 = wasserstein(dgm, empty)
182182
assert (np.allclose(dist1, np.sqrt(2)/2)) and (np.allclose(dist2, np.sqrt(2)/2))
183+
184+
def test_repeated(self):
185+
dgm1 = np.array([[0, 10], [0,10], [0,11]])
186+
dgm2 = np.array([[0, 10]])
187+
dist = wasserstein(dgm1, dgm2)
188+
assert dist == 10.5
183189

184190

185191
class TestSliced:

0 commit comments

Comments
 (0)