Skip to content

Commit 6bb7562

Browse files
committed
Merge branch 'master' of https://github.com/scikit-tda/persim
2 parents fa6fba0 + 82d8687 commit 6bb7562

File tree

3 files changed

+116
-37
lines changed

3 files changed

+116
-37
lines changed

persim/bottleneck.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
3-
Implementation of the bottleneck distance
3+
Implementation of the bottleneck distance using binary
4+
search and the Hopcroft-Karp algorithm
45
56
Author: Chris Tralie
67
@@ -10,6 +11,7 @@
1011

1112
from bisect import bisect_left
1213
from hopcroftkarp import HopcroftKarp
14+
import warnings
1315

1416
__all__ = ["bottleneck"]
1517

@@ -44,12 +46,32 @@ def bottleneck(dgm1, dgm2, matching=False):
4446
return_matching = matching
4547

4648
S = np.array(dgm1)
47-
S = S[np.isfinite(S[:, 1]), :]
49+
M = min(S.shape[0], S.size)
50+
if S.size > 0:
51+
S = S[np.isfinite(S[:, 1]), :]
52+
if S.shape[0] < M:
53+
warnings.warn(
54+
"dgm1 has points with non-finite death times;"+
55+
"ignoring those points"
56+
)
57+
M = S.shape[0]
4858
T = np.array(dgm2)
49-
T = T[np.isfinite(T[:, 1]), :]
50-
51-
N = S.shape[0]
52-
M = T.shape[0]
59+
N = min(T.shape[0], T.size)
60+
if T.size > 0:
61+
T = T[np.isfinite(T[:, 1]), :]
62+
if T.shape[0] < N:
63+
warnings.warn(
64+
"dgm2 has points with non-finite death times;"+
65+
"ignoring those points"
66+
)
67+
N = T.shape[0]
68+
69+
if M == 0:
70+
S = np.array([[0, 0]])
71+
M = 1
72+
if N == 0:
73+
T = np.array([[0, 0]])
74+
N = 1
5375

5476
# Step 1: Compute CSM between S and T, including points on diagonal
5577
# L Infinity distance
@@ -61,18 +83,18 @@ def bottleneck(dgm1, dgm2, matching=False):
6183

6284
# Put diagonal elements into the matrix, being mindful that Linfinity
6385
# balls meet the diagonal line at a diamond vertex
64-
D = np.zeros((N + M, N + M))
65-
D[0:N, 0:M] = DUL
66-
UR = np.max(D) * np.ones((N, N))
86+
D = np.zeros((M + N, M + N))
87+
D[0:M, 0:N] = DUL
88+
UR = np.max(D) * np.ones((M, M))
6789
np.fill_diagonal(UR, 0.5 * (S[:, 1] - S[:, 0]))
68-
D[0:N, M::] = UR
69-
UL = np.max(D) * np.ones((M, M))
90+
D[0:M, N::] = UR
91+
UL = np.max(D) * np.ones((N, N))
7092
np.fill_diagonal(UL, 0.5 * (T[:, 1] - T[:, 0]))
71-
D[N::, 0:M] = UL
93+
D[M::, 0:N] = UL
7294

7395
# Step 2: Perform a binary search + Hopcroft Karp to find the
7496
# bottleneck distance
75-
N = D.shape[0]
97+
M = D.shape[0]
7698
ds = np.sort(np.unique(D.flatten()))
7799
bdist = ds[-1]
78100
matching = {}
@@ -82,18 +104,18 @@ def bottleneck(dgm1, dgm2, matching=False):
82104
idx = bisect_left(range(ds.size), int(ds.size / 2))
83105
d = ds[idx]
84106
graph = {}
85-
for i in range(N):
86-
graph["%s" % i] = {j for j in range(N) if D[i, j] <= d}
107+
for i in range(M):
108+
graph["%s" % i] = {j for j in range(M) if D[i, j] <= d}
87109
res = HopcroftKarp(graph).maximum_matching()
88-
if len(res) == 2 * N and d <= bdist:
110+
if len(res) == 2 * M and d <= bdist:
89111
bdist = d
90112
matching = res
91113
ds = ds[0:idx]
92114
else:
93115
ds = ds[idx + 1::]
94116

95117
if return_matching:
96-
matchidx = [(i, matching["%i" % i]) for i in range(N)]
118+
matchidx = [(i, matching["%i" % i]) for i in range(M)]
97119
return bdist, (matchidx, D)
98120
else:
99121
return bdist

persim/wasserstein.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
"""
2+
3+
Implementation of the Wasserstein distance using
4+
the Hungarian algorithm
5+
6+
Author: Chris Tralie
7+
8+
"""
19
import numpy as np
210
from sklearn import metrics
311
from scipy import optimize
12+
import warnings
413

514
__all__ = ["wasserstein"]
615

@@ -22,7 +31,7 @@ def wasserstein(dgm1, dgm2, matching=False):
2231
dgm2: Nx(>=2)
2332
array of birth/death pairs for PD 2
2433
matching: bool, default False
25-
if True, return matching infromation and cross-similarity matrix
34+
if True, return matching information and cross-similarity matrix
2635
2736
Returns
2837
---------
@@ -34,34 +43,52 @@ def wasserstein(dgm1, dgm2, matching=False):
3443
3544
"""
3645

37-
# Step 1: Compute CSM between S and dgm2, including points on diagonal
38-
N = dgm1.shape[0]
39-
M = dgm2.shape[0]
40-
# Handle the cases where there are no points in the diagrams
41-
if N == 0:
42-
dgm1 = np.array([[0, 0]])
43-
N = 1
46+
S = np.array(dgm1)
47+
M = min(S.shape[0], S.size)
48+
if S.size > 0:
49+
S = S[np.isfinite(S[:, 1]), :]
50+
if S.shape[0] < M:
51+
warnings.warn(
52+
"dgm1 has points with non-finite death times;"+
53+
"ignoring those points"
54+
)
55+
M = S.shape[0]
56+
T = np.array(dgm2)
57+
N = min(T.shape[0], T.size)
58+
if T.size > 0:
59+
T = T[np.isfinite(T[:, 1]), :]
60+
if T.shape[0] < N:
61+
warnings.warn(
62+
"dgm2 has points with non-finite death times;"+
63+
"ignoring those points"
64+
)
65+
N = T.shape[0]
66+
4467
if M == 0:
45-
dgm2 = np.array([[0, 0]])
68+
S = np.array([[0, 0]])
4669
M = 1
47-
DUL = metrics.pairwise.pairwise_distances(dgm1, dgm2)
70+
if N == 0:
71+
T = np.array([[0, 0]])
72+
N = 1
73+
# Step 1: Compute CSM between S and T, including points on diagonal
74+
DUL = metrics.pairwise.pairwise_distances(S, T)
4875

4976
# Put diagonal elements into the matrix
5077
# Rotate the diagrams to make it easy to find the straight line
5178
# distance to the diagonal
5279
cp = np.cos(np.pi/4)
5380
sp = np.sin(np.pi/4)
5481
R = np.array([[cp, -sp], [sp, cp]])
55-
dgm1 = dgm1[:, 0:2].dot(R)
56-
dgm2 = dgm2[:, 0:2].dot(R)
57-
D = np.zeros((N+M, N+M))
58-
D[0:N, 0:M] = DUL
59-
UR = np.max(D)*np.ones((N, N))
60-
np.fill_diagonal(UR, dgm1[:, 1])
61-
D[0:N, M:M+N] = UR
62-
UL = np.max(D)*np.ones((M, M))
63-
np.fill_diagonal(UL, dgm2[:, 1])
64-
D[N:M+N, 0:M] = UL
82+
S = S[:, 0:2].dot(R)
83+
T = T[:, 0:2].dot(R)
84+
D = np.zeros((M+N, M+N))
85+
D[0:M, 0:N] = DUL
86+
UR = np.max(D)*np.ones((M, M))
87+
np.fill_diagonal(UR, S[:, 1])
88+
D[0:M, N:N+M] = UR
89+
UL = np.max(D)*np.ones((N, N))
90+
np.fill_diagonal(UL, T[:, 1])
91+
D[M:N+M, 0:N] = UL
6592

6693
# Step 2: Run the hungarian algorithm
6794
matchi, matchj = optimize.linear_sum_assignment(D)

test/test_distances.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,21 @@ def test_2x2_bisect_bug(self):
106106
dgm2 = np.array([[4, 10], [9, 10]])
107107
dist = bottleneck(dgm1, dgm2)
108108
assert dist == 2
109+
110+
def test_one_empty(self):
111+
dgm1 = np.array([[1, 2]])
112+
empty = np.array([[]])
113+
dist = bottleneck(dgm1, empty)
114+
assert dist == 0.5
115+
116+
def test_inf_deathtime(self):
117+
dgm = np.array([[1, 2]])
118+
empty = np.array([[0, np.inf]])
119+
with pytest.warns(UserWarning, match="dgm1 has points with non-finite death") as w:
120+
dist1 = bottleneck(empty, dgm)
121+
with pytest.warns(UserWarning, match="dgm2 has points with non-finite death") as w:
122+
dist2 = bottleneck(dgm, empty)
123+
assert (dist1 == 0.5) and (dist2 == 0.5)
109124

110125
class TestWasserstein:
111126
def test_single(self):
@@ -145,6 +160,21 @@ def test_single_point_same(self):
145160
dgm = np.array([[0.11371516, 4.45734882]])
146161
dist = wasserstein(dgm, dgm)
147162
assert dist == 0
163+
164+
def test_one_empty(self):
165+
dgm1 = np.array([[1, 2]])
166+
empty = np.array([])
167+
dist = wasserstein(dgm1, empty)
168+
assert np.allclose(dist, np.sqrt(2)/2)
169+
170+
def test_inf_deathtime(self):
171+
dgm = np.array([[1, 2]])
172+
empty = np.array([[0, np.inf]])
173+
with pytest.warns(UserWarning, match="dgm1 has points with non-finite death") as w:
174+
dist1 = wasserstein(empty, dgm)
175+
with pytest.warns(UserWarning, match="dgm2 has points with non-finite death") as w:
176+
dist2 = wasserstein(dgm, empty)
177+
assert (np.allclose(dist1, np.sqrt(2)/2)) and (np.allclose(dist2, np.sqrt(2)/2))
148178

149179

150180
class TestSliced:

0 commit comments

Comments
 (0)