Skip to content

Commit c35fa04

Browse files
authored
Merge pull request #53 from scikit-tda/better_corresp_api
Better corresp api
2 parents a040a58 + bf64346 commit c35fa04

File tree

8 files changed

+213
-125
lines changed

8 files changed

+213
-125
lines changed

RELEASE.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
0.3.1
2+
- Fixed bug with repeated intervals in bottleneck
3+
- Tidied up API for indicating matchings for bottleneck and wasserstein, and updated notebook
4+
15
0.3.0
26
- Add implementations of Persistence Landscapes, including plotting methods, a transformer, and additional notebooks.
37

docs/notebooks/distances.ipynb

Lines changed: 48 additions & 36 deletions
Large diffs are not rendered by default.

persim/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.0"
1+
__version__ = "0.3.1"

persim/bottleneck.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def bottleneck(dgm1, dgm2, matching=False):
3939
4040
d: float
4141
bottleneck distance between dgm1 and dgm2
42-
(matching, D): Only returns if `matching=True`
43-
(tuples of matched indices, (N+M)x(N+M) cross-similarity matrix)
42+
matching: ndarray(Mx+Nx, 3), Only returns if `matching=True`
43+
A list of correspondences in an optimal matching, as well as their distance, where:
44+
* First column is index of point in first persistence diagram, or -1 if diagonal
45+
* Second column is index of point in second persistence diagram, or -1 if diagonal
46+
* Third column is the distance of each matching
4447
"""
4548

4649
return_matching = matching
@@ -84,18 +87,21 @@ def bottleneck(dgm1, dgm2, matching=False):
8487
# Put diagonal elements into the matrix, being mindful that Linfinity
8588
# balls meet the diagonal line at a diamond vertex
8689
D = np.zeros((M + N, M + N))
90+
# Upper left is Linfinity cross-similarity between two diagrams
8791
D[0:M, 0:N] = DUL
88-
UR = np.max(D) * np.ones((M, M))
92+
# Upper right is diagonal matching of points from S
93+
UR = np.inf * np.ones((M, M))
8994
np.fill_diagonal(UR, 0.5 * (S[:, 1] - S[:, 0]))
9095
D[0:M, N::] = UR
91-
UL = np.max(D) * np.ones((N, N))
96+
# Lower left is diagonal matching of points from T
97+
UL = np.inf * np.ones((N, N))
9298
np.fill_diagonal(UL, 0.5 * (T[:, 1] - T[:, 0]))
9399
D[M::, 0:N] = UL
100+
# Lower right is all 0s by default (remaining diagonals match to diagonals)
94101

95102
# Step 2: Perform a binary search + Hopcroft Karp to find the
96103
# bottleneck distance
97-
M = D.shape[0]
98-
ds = np.sort(np.unique(D.flatten()))
104+
ds = np.sort(np.unique(D.flatten()))[0:-1] # Everything but np.inf
99105
bdist = ds[-1]
100106
matching = {}
101107
while len(ds) >= 1:
@@ -104,18 +110,29 @@ def bottleneck(dgm1, dgm2, matching=False):
104110
idx = bisect_left(range(ds.size), int(ds.size / 2))
105111
d = ds[idx]
106112
graph = {}
107-
for i in range(M):
108-
graph["%s" % i] = {j for j in range(M) if D[i, j] <= d}
113+
for i in range(D.shape[0]):
114+
graph["{}".format(i)] = {j for j in range(D.shape[1]) if D[i, j] <= d}
109115
res = HopcroftKarp(graph).maximum_matching()
110-
if len(res) == 2 * M and d <= bdist:
116+
if len(res) == 2 * D.shape[0] and d <= bdist:
111117
bdist = d
112118
matching = res
113119
ds = ds[0:idx]
114120
else:
115121
ds = ds[idx + 1::]
116122

117123
if return_matching:
118-
matchidx = [(i, matching["%i" % i]) for i in range(M)]
119-
return bdist, (matchidx, D)
124+
matchidx = []
125+
for i in range(M+N):
126+
j = matching["{}".format(i)]
127+
d = D[i, j]
128+
if i < M:
129+
if j >= N:
130+
j = -1 # Diagonal match from first persistence diagram
131+
else:
132+
if j >= N: # Diagonal to diagonal, so don't include this
133+
continue
134+
i = -1
135+
matchidx.append([i, j, d])
136+
return bdist, np.array(matchidx)
120137
else:
121138
return bdist

persim/visuals.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,21 @@ def plot_diagrams(
168168
def plot_a_bar(p, q, c='b', linestyle='-'):
169169
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)
170170

171-
def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
171+
def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
172172
""" Visualize bottleneck matching between two diagrams
173173
174174
Parameters
175175
===========
176176
177-
I1: array
178-
A diagram
179-
I2: array
180-
A diagram
181-
matchidx: tuples of matched indices
182-
if input `matching=True`, then return matching
183-
D: array
184-
cross-similarity matrix
177+
dgm1: Mx(>=2)
178+
array of birth/death pairs for PD 1
179+
dgm2: Nx(>=2)
180+
array of birth/death pairs for PD 2
181+
matching: ndarray(Mx+Nx, 3)
182+
A list of correspondences in an optimal matching, as well as their distance, where:
183+
* First column is index of point in first persistence diagram, or -1 if diagonal
184+
* Second column is index of point in second persistence diagram, or -1 if diagonal
185+
* Third column is the distance of each matching
185186
labels: list of strings
186187
names of diagrams for legend. Default = ["dgm1", "dgm2"],
187188
ax: matplotlib Axis object
@@ -191,50 +192,61 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
191192
Examples
192193
==========
193194
194-
bn_matching, (matchidx, D) = persim.bottleneck(A_h1, B_h1, matching=True)
195-
persim.bottleneck_matching(A_h1, B_h1, matchidx, D)
195+
dist, matching = persim.bottleneck(A_h1, B_h1, matching=True)
196+
persim.bottleneck_matching(A_h1, B_h1, matching)
196197
197198
"""
198199
ax = ax or plt.gca()
199200

200-
plot_diagrams([I1, I2], labels=labels, ax=ax)
201+
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)
201202
cp = np.cos(np.pi / 4)
202203
sp = np.sin(np.pi / 4)
203204
R = np.array([[cp, -sp], [sp, cp]])
204-
if I1.size == 0:
205-
I1 = np.array([[0, 0]])
206-
if I2.size == 0:
207-
I2 = np.array([[0, 0]])
208-
I1Rot = I1.dot(R)
209-
I2Rot = I2.dot(R)
210-
dists = [D[i, j] for (i, j) in matchidx]
211-
(i, j) = matchidx[np.argmax(dists)]
212-
if i >= I1.shape[0] and j >= I2.shape[0]:
213-
return
214-
if i >= I1.shape[0]:
215-
diagElem = np.array([I2Rot[j, 0], 0])
216-
diagElem = diagElem.dot(R.T)
217-
plt.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
218-
elif j >= I2.shape[0]:
219-
diagElem = np.array([I1Rot[i, 0], 0])
220-
diagElem = diagElem.dot(R.T)
221-
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
222-
else:
223-
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
224-
225-
226-
def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"], colors=None, ax=None):
205+
if dgm1.size == 0:
206+
dgm1 = np.array([[0, 0]])
207+
if dgm2.size == 0:
208+
dgm2 = np.array([[0, 0]])
209+
dgm1Rot = dgm1.dot(R)
210+
dgm2Rot = dgm2.dot(R)
211+
max_idx = np.argmax(matching[:, 2])
212+
for idx, [i, j, d] in enumerate(matching):
213+
i = int(i)
214+
j = int(j)
215+
linestyle = '--'
216+
linewidth = 1
217+
c = 'C2'
218+
if idx == max_idx:
219+
linestyle = '-'
220+
linewidth = 2
221+
c = 'C3'
222+
if i != -1 or j != -1: # At least one point is a non-diagonal point
223+
if i == -1:
224+
diagElem = np.array([dgm2Rot[j, 0], 0])
225+
diagElem = diagElem.dot(R.T)
226+
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
227+
elif j == -1:
228+
diagElem = np.array([dgm1Rot[i, 0], 0])
229+
diagElem = diagElem.dot(R.T)
230+
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
231+
else:
232+
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], c, linewidth=linewidth, linestyle=linestyle)
233+
234+
235+
def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
227236
""" Visualize Wasserstein matching between two diagrams
228237
229238
Parameters
230239
===========
231240
232-
I1: array
241+
dgm1: array
233242
A diagram
234-
I2: array
243+
dgm2: array
235244
A diagram
236-
matchidx: tuples of matched indices
237-
if input `matching=True`, then return matching
245+
matching: ndarray(Mx+Nx, 3)
246+
A list of correspondences in an optimal matching, as well as their distance, where:
247+
* First column is index of point in first persistence diagram, or -1 if diagonal
248+
* Second column is index of point in second persistence diagram, or -1 if diagonal
249+
* Third column is the distance of each matching
238250
labels: list of strings
239251
names of diagrams for legend. Default = ["dgm1", "dgm2"],
240252
ax: matplotlib Axis object
@@ -252,25 +264,25 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
252264
cp = np.cos(np.pi / 4)
253265
sp = np.sin(np.pi / 4)
254266
R = np.array([[cp, -sp], [sp, cp]])
255-
if I1.size == 0:
256-
I1 = np.array([[0, 0]])
257-
if I2.size == 0:
258-
I2 = np.array([[0, 0]])
259-
I1Rot = I1.dot(R)
260-
I2Rot = I2.dot(R)
261-
for index in matchidx:
262-
(i, j) = index
263-
if i >= I1.shape[0] and j >= I2.shape[0]:
264-
continue
265-
if i >= I1.shape[0]:
266-
diagElem = np.array([I2Rot[j, 0], 0])
267-
diagElem = diagElem.dot(R.T)
268-
plt.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
269-
elif j >= I2.shape[0]:
270-
diagElem = np.array([I1Rot[i, 0], 0])
271-
diagElem = diagElem.dot(R.T)
272-
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
273-
else:
274-
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
275-
276-
plot_diagrams([I1, I2], labels=labels, ax=ax)
267+
if dgm1.size == 0:
268+
dgm1 = np.array([[0, 0]])
269+
if dgm2.size == 0:
270+
dgm2 = np.array([[0, 0]])
271+
dgm1Rot = dgm1.dot(R)
272+
dgm2Rot = dgm2.dot(R)
273+
for [i, j, d] in matching:
274+
i = int(i)
275+
j = int(j)
276+
if i != -1 or j != -1: # At least one point is a non-diagonal point
277+
if i == -1:
278+
diagElem = np.array([dgm2Rot[j, 0], 0])
279+
diagElem = diagElem.dot(R.T)
280+
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], "g")
281+
elif j == -1:
282+
diagElem = np.array([dgm1Rot[i, 0], 0])
283+
diagElem = diagElem.dot(R.T)
284+
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], "g")
285+
else:
286+
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], "g")
287+
288+
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)

persim/wasserstein.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def wasserstein(dgm1, dgm2, matching=False):
7070
if N == 0:
7171
T = np.array([[0, 0]])
7272
N = 1
73-
# Step 1: Compute CSM between S and dgm2, including points on diagonal
73+
# Compute CSM between S and dgm2, including points on diagonal
7474
DUL = metrics.pairwise.pairwise_distances(S, T)
7575

7676
# Put diagonal elements into the matrix
@@ -82,11 +82,12 @@ def wasserstein(dgm1, dgm2, matching=False):
8282
S = S[:, 0:2].dot(R)
8383
T = T[:, 0:2].dot(R)
8484
D = np.zeros((M+N, M+N))
85+
np.fill_diagonal(D, 0)
8586
D[0:M, 0:N] = DUL
86-
UR = np.max(D)*np.ones((M, M))
87+
UR = np.inf*np.ones((M, M))
8788
np.fill_diagonal(UR, S[:, 1])
8889
D[0:M, N:N+M] = UR
89-
UL = np.max(D)*np.ones((N, N))
90+
UL = np.inf*np.ones((N, N))
9091
np.fill_diagonal(UL, T[:, 1])
9192
D[M:N+M, 0:N] = UL
9293

@@ -96,6 +97,14 @@ def wasserstein(dgm1, dgm2, matching=False):
9697

9798
if matching:
9899
matchidx = [(i, j) for i, j in zip(matchi, matchj)]
99-
return matchdist, (matchidx, D)
100+
ret = np.zeros((len(matchidx), 3))
101+
ret[:, 0:2] = np.array(matchidx)
102+
ret[:, 2] = D[matchi, matchj]
103+
# Indicate diagonally matched points
104+
ret[ret[:, 0] >= M, 0] = -1
105+
ret[ret[:, 1] >= N, 1] = -1
106+
# Exclude diagonal to diagonal
107+
ret = ret[ret[:, 0] + ret[:, 1] != -2, :]
108+
return matchdist, ret
100109

101110
return matchdist

test/test_distances.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,7 @@ def test_different_size(self):
6161
[0.5, 1.1]
6262
])
6363
)
64-
65-
# These are very loose bounds
66-
assert d == pytest.approx(0.1, 0.001)
64+
assert d == 0.25
6765

6866
def test_matching(self):
6967
dgm1 = np.array([
@@ -77,14 +75,15 @@ def test_matching(self):
7775
[1.0, 1.1],
7876
])
7977

80-
d, (m, D) = bottleneck(
78+
d, m = bottleneck(
8179
dgm1, dgm2,
8280
matching=True
8381
)
84-
85-
# These are very loose bounds
86-
assert len(m) == len(dgm1) + len(dgm2)
87-
assert D.shape == (len(dgm1) + len(dgm2), len(dgm1) + len(dgm2))
82+
u1 = np.unique(m[:, 0])
83+
u1 = u1[u1 >= 0]
84+
u2 = np.unique(m[:, 1])
85+
u2 = u2[u2 >= 0]
86+
assert u1.size == dgm1.shape[0] and u2.size == dgm2.shape[0]
8887

8988
def test_matching_to_self(self):
9089
# Matching a diagram to itself should yield 0
@@ -122,6 +121,13 @@ def test_inf_deathtime(self):
122121
dist2 = bottleneck(dgm, empty)
123122
assert (dist1 == 0.5) and (dist2 == 0.5)
124123

124+
def test_repeated(self):
125+
# Issue #44
126+
G = np.array([[ 0, 1], [0,1]])
127+
H = np.array([[ 0, 1]])
128+
dist = bottleneck(G, H)
129+
assert dist == 0.5
130+
125131
class TestWasserstein:
126132
def test_single(self):
127133
d = wasserstein(
@@ -175,6 +181,34 @@ def test_inf_deathtime(self):
175181
with pytest.warns(UserWarning, match="dgm2 has points with non-finite death") as w:
176182
dist2 = wasserstein(dgm, empty)
177183
assert (np.allclose(dist1, np.sqrt(2)/2)) and (np.allclose(dist2, np.sqrt(2)/2))
184+
185+
def test_repeated(self):
186+
dgm1 = np.array([[0, 10], [0,10]])
187+
dgm2 = np.array([[0, 10]])
188+
dist = wasserstein(dgm1, dgm2)
189+
assert dist == 5*np.sqrt(2)
190+
191+
def test_matching(self):
192+
dgm1 = np.array([
193+
[0.5, 1],
194+
[0.6, 1.1]
195+
])
196+
dgm2 = np.array([
197+
[0.5, 1.1],
198+
[0.6, 1.1],
199+
[0.8, 1.1],
200+
[1.0, 1.1],
201+
])
202+
203+
d, m = wasserstein(
204+
dgm1, dgm2,
205+
matching=True
206+
)
207+
u1 = np.unique(m[:, 0])
208+
u1 = u1[u1 >= 0]
209+
u2 = np.unique(m[:, 1])
210+
u2 = u2[u2 >= 0]
211+
assert u1.size == dgm1.shape[0] and u2.size == dgm2.shape[0]
178212

179213

180214
class TestSliced:

test/test_visuals.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def test_bottleneck_matching(self):
198198
[0.3, 0.45]
199199
])
200200

201-
d, (matching, D) = persim.bottleneck(dgm1, dgm2, matching=True)
202-
persim.bottleneck_matching(dgm1, dgm2, matching, D)
201+
d, matching = persim.bottleneck(dgm1, dgm2, matching=True)
202+
persim.bottleneck_matching(dgm1, dgm2, matching)
203203

204204
def test_plot_labels(self):
205205
dgm1 = np.array([
@@ -211,6 +211,6 @@ def test_plot_labels(self):
211211
[0.3, 0.45]
212212
])
213213

214-
d, (matching, D) = persim.bottleneck(dgm1, dgm2, matching=True)
214+
d, matching = persim.bottleneck(dgm1, dgm2, matching=True)
215215
persim.bottleneck_matching(
216-
dgm1, dgm2, matching, D, labels=["X", "Y"])
216+
dgm1, dgm2, matching, labels=["X", "Y"])

0 commit comments

Comments
 (0)