Skip to content

Commit 859d619

Browse files
authored
Merge pull request #29 from scikit-tda/single-point-same-bug
Fixes #28
2 parents 8b2a7f4 + fee60ae commit 859d619

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

persim/bottleneck.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def bottleneck(dgm1, dgm2, matching=False):
8585
for i in range(N):
8686
graph["%s" % i] = {j for j in range(N) if D[i, j] <= d}
8787
res = HopcroftKarp(graph).maximum_matching()
88-
if len(res) == 2 * N and d < bdist:
88+
if len(res) == 2 * N and d <= bdist:
8989
bdist = d
9090
matching = res
9191
ds = ds[0:idx]

test/test_distances.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,17 @@ def test_matching_to_self(self):
9696
dist = bottleneck(pd, pd)
9797
assert dist == 0
9898

99+
def test_single_point_same(self):
100+
dgm = np.array([[0.11371516, 4.45734882]])
101+
dist = bottleneck(dgm, dgm)
102+
assert dist == 0
103+
104+
def test_2x2_bisect_bug(self):
105+
dgm1 = np.array([[6, 9], [6, 8]])
106+
dgm2 = np.array([[4, 10], [9, 10]])
107+
dist = bottleneck(dgm1, dgm2)
108+
assert dist == 2
109+
99110
class TestWasserstein:
100111
def test_single(self):
101112
d = wasserstein(
@@ -130,6 +141,11 @@ def test_matching_to_self(self):
130141
dist = wasserstein(pd, pd)
131142
assert dist == 0
132143

144+
def test_single_point_same(self):
145+
dgm = np.array([[0.11371516, 4.45734882]])
146+
dist = wasserstein(dgm, dgm)
147+
assert dist == 0
148+
133149

134150
class TestSliced:
135151
def test_single(self):
@@ -170,6 +186,10 @@ def test_different_size(self):
170186
# These are very loose bounds
171187
assert d == pytest.approx(0.314, 0.1)
172188

189+
def test_single_point_same(self):
190+
dgm = np.array([[0.11371516, 4.45734882]])
191+
dist = sliced_wasserstein(dgm, dgm)
192+
assert dist == 0
173193

174194
class TestHeat:
175195
def test_compare(self):
@@ -186,6 +206,11 @@ def test_compare(self):
186206
# These are very loose bounds
187207
assert d1 < d2
188208

209+
def test_single_point_same(self):
210+
dgm = np.array([[0.11371516, 4.45734882]])
211+
dist = heat(dgm, dgm)
212+
assert dist == 0
213+
189214

190215
class TestModifiedGromovHausdorff:
191216
def test_single_point(self):

0 commit comments

Comments
 (0)