Skip to content

Commit cc5c351

Browse files
committed
Fixed up bottleneck and came up with better conventions for plotting matching, but wasserstein is still broken
1 parent b8a0af7 commit cc5c351

File tree

6 files changed

+172
-123
lines changed

6 files changed

+172
-123
lines changed

docs/notebooks/distances.ipynb

Lines changed: 48 additions & 36 deletions
Large diffs are not rendered by default.

persim/bottleneck.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def bottleneck(dgm1, dgm2, matching=False):
3939
4040
d: float
4141
bottleneck distance between dgm1 and dgm2
42-
(matching, D): Only returns if `matching=True`
43-
(tuples of matched indices, (N+M)x(N+M) cross-similarity matrix)
42+
matching: ndarray(Mx+Nx, 3), Only returned if `matching=True`
43+
A list of correspondences in an optimal matching, as well as their distance, where:
44+
* First column is index of point in first persistence diagram, or -1 if diagonal
45+
* Second column is index of point in second persistence diagram, or -1 if diagonal
46+
* Third column is the distance of each matching
4447
"""
4548

4649
return_matching = matching
@@ -83,19 +86,21 @@ def bottleneck(dgm1, dgm2, matching=False):
8386

8487
# Put diagonal elements into the matrix, being mindful that Linfinity
8588
# balls meet the diagonal line at a diamond vertex
86-
D = np.inf*np.ones((M + N, M + N))
87-
np.fill_diagonal(D, 0)
89+
D = np.zeros((M + N, M + N))
90+
# Upper left is Linfinity cross-similarity between two diagrams
8891
D[0:M, 0:N] = DUL
89-
UR = np.max(D) * np.ones((M, M))
92+
# Upper right is diagonal matching of points from S
93+
UR = np.inf * np.ones((M, M))
9094
np.fill_diagonal(UR, 0.5 * (S[:, 1] - S[:, 0]))
9195
D[0:M, N::] = UR
92-
UL = np.max(D) * np.ones((N, N))
96+
# Lower left is diagonal matching of points from T
97+
UL = np.inf * np.ones((N, N))
9398
np.fill_diagonal(UL, 0.5 * (T[:, 1] - T[:, 0]))
9499
D[M::, 0:N] = UL
100+
# Lower right is all 0s by default (remaining diagonals match to diagonals)
95101

96102
# Step 2: Perform a binary search + Hopcroft Karp to find the
97103
# bottleneck distance
98-
M = D.shape[0]
99104
ds = np.sort(np.unique(D.flatten()))[0:-1] # Everything but np.inf
100105
bdist = ds[-1]
101106
matching = {}
@@ -105,18 +110,29 @@ def bottleneck(dgm1, dgm2, matching=False):
105110
idx = bisect_left(range(ds.size), int(ds.size / 2))
106111
d = ds[idx]
107112
graph = {}
108-
for i in range(M):
109-
graph["%s" % i] = {j for j in range(M) if D[i, j] <= d}
113+
for i in range(D.shape[0]):
114+
graph["{}".format(i)] = {j for j in range(D.shape[1]) if D[i, j] <= d}
110115
res = HopcroftKarp(graph).maximum_matching()
111-
if len(res) == 2 * M and d <= bdist:
116+
if len(res) == 2 * D.shape[0] and d <= bdist:
112117
bdist = d
113118
matching = res
114119
ds = ds[0:idx]
115120
else:
116121
ds = ds[idx + 1::]
117122

118123
if return_matching:
119-
matchidx = [(i, matching["%i" % i]) for i in range(M)]
120-
return bdist, (matchidx, D)
124+
matchidx = []
125+
for i in range(M+N):
126+
j = matching["{}".format(i)]
127+
d = D[i, j]
128+
if i < M:
129+
if j >= N:
130+
j = -1 # Diagonal match from first persistence diagram
131+
else:
132+
if j >= N: # Diagonal to diagonal, so don't include this
133+
continue
134+
i = -1
135+
matchidx.append([i, j, d])
136+
return bdist, np.array(matchidx)
121137
else:
122138
return bdist

persim/visuals.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,21 @@ def plot_diagrams(
168168
def plot_a_bar(p, q, c='b', linestyle='-'):
169169
plt.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)
170170

171-
def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
171+
def bottleneck_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
172172
""" Visualize bottleneck matching between two diagrams
173173
174174
Parameters
175175
===========
176176
177-
I1: array
178-
A diagram
179-
I2: array
180-
A diagram
181-
matchidx: tuples of matched indices
182-
if input `matching=True`, then return matching
183-
D: array
184-
cross-similarity matrix
177+
dgm1: Mx(>=2)
178+
array of birth/death pairs for PD 1
179+
dgm2: Nx(>=2)
180+
array of birth/death pairs for PD 2
181+
matching: ndarray(Mx+Nx, 3)
182+
A list of correspondences in an optimal matching, as well as their distance, where:
183+
* First column is index of point in first persistence diagram, or -1 if diagonal
184+
* Second column is index of point in second persistence diagram, or -1 if diagonal
185+
* Third column is the distance of each matching
185186
labels: list of strings
186187
names of diagrams for legend. Default = ["dgm1", "dgm2"],
187188
ax: matplotlib Axis object
@@ -191,50 +192,61 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
191192
Examples
192193
==========
193194
194-
bn_matching, (matchidx, D) = persim.bottleneck(A_h1, B_h1, matching=True)
195-
persim.bottleneck_matching(A_h1, B_h1, matchidx, D)
195+
dist, matching = persim.bottleneck(A_h1, B_h1, matching=True)
196+
persim.bottleneck_matching(A_h1, B_h1, matching)
196197
197198
"""
198199
ax = ax or plt.gca()
199200

200-
plot_diagrams([I1, I2], labels=labels, ax=ax)
201+
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)
201202
cp = np.cos(np.pi / 4)
202203
sp = np.sin(np.pi / 4)
203204
R = np.array([[cp, -sp], [sp, cp]])
204-
if I1.size == 0:
205-
I1 = np.array([[0, 0]])
206-
if I2.size == 0:
207-
I2 = np.array([[0, 0]])
208-
I1Rot = I1.dot(R)
209-
I2Rot = I2.dot(R)
210-
dists = [D[i, j] for (i, j) in matchidx]
211-
(i, j) = matchidx[np.argmax(dists)]
212-
if i >= I1.shape[0] and j >= I2.shape[0]:
213-
return
214-
if i >= I1.shape[0]:
215-
diagElem = np.array([I2Rot[j, 0], 0])
216-
diagElem = diagElem.dot(R.T)
217-
plt.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
218-
elif j >= I2.shape[0]:
219-
diagElem = np.array([I1Rot[i, 0], 0])
220-
diagElem = diagElem.dot(R.T)
221-
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
222-
else:
223-
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
224-
225-
226-
def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"], colors=None, ax=None):
205+
if dgm1.size == 0:
206+
dgm1 = np.array([[0, 0]])
207+
if dgm2.size == 0:
208+
dgm2 = np.array([[0, 0]])
209+
dgm1Rot = dgm1.dot(R)
210+
dgm2Rot = dgm2.dot(R)
211+
max_idx = np.argmax(matching[:, 2])
212+
for idx, [i, j, d] in enumerate(matching):
213+
i = int(i)
214+
j = int(j)
215+
linestyle = '--'
216+
linewidth = 1
217+
c = 'C2'
218+
if idx == max_idx:
219+
linestyle = '-'
220+
linewidth = 2
221+
c = 'C3'
222+
if i != -1 or j != -1: # At least one point is a non-diagonal point
223+
if i == -1:
224+
diagElem = np.array([dgm2Rot[j, 0], 0])
225+
diagElem = diagElem.dot(R.T)
226+
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
227+
elif j == -1:
228+
diagElem = np.array([dgm1Rot[i, 0], 0])
229+
diagElem = diagElem.dot(R.T)
230+
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], c, linewidth=linewidth, linestyle=linestyle)
231+
else:
232+
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], c, linewidth=linewidth, linestyle=linestyle)
233+
234+
235+
def wasserstein_matching(dgm1, dgm2, matching, labels=["dgm1", "dgm2"], ax=None):
227236
""" Visualize Wasserstein matching between two diagrams
228237
229238
Parameters
230239
===========
231240
232-
I1: array
241+
dgm1: array
233242
A diagram
234-
I2: array
243+
dgm2: array
235244
A diagram
236-
matchidx: tuples of matched indices
237-
if input `matching=True`, then return matching
245+
matching: ndarray(Mx+Nx, 3)
246+
A list of correspondences in an optimal matching, as well as their distance, where:
247+
* First column is index of point in first persistence diagram, or -1 if diagonal
248+
* Second column is index of point in second persistence diagram, or -1 if diagonal
249+
* Third column is the distance of each matching
238250
labels: list of strings
239251
names of diagrams for legend. Default = ["dgm1", "dgm2"],
240252
ax: matplotlib Axis object
@@ -252,25 +264,25 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
252264
cp = np.cos(np.pi / 4)
253265
sp = np.sin(np.pi / 4)
254266
R = np.array([[cp, -sp], [sp, cp]])
255-
if I1.size == 0:
256-
I1 = np.array([[0, 0]])
257-
if I2.size == 0:
258-
I2 = np.array([[0, 0]])
259-
I1Rot = I1.dot(R)
260-
I2Rot = I2.dot(R)
261-
for index in matchidx:
262-
(i, j) = index
263-
if i >= I1.shape[0] and j >= I2.shape[0]:
264-
continue
265-
if i >= I1.shape[0]:
266-
diagElem = np.array([I2Rot[j, 0], 0])
267-
diagElem = diagElem.dot(R.T)
268-
plt.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
269-
elif j >= I2.shape[0]:
270-
diagElem = np.array([I1Rot[i, 0], 0])
271-
diagElem = diagElem.dot(R.T)
272-
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
273-
else:
274-
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
275-
276-
plot_diagrams([I1, I2], labels=labels, ax=ax)
267+
if dgm1.size == 0:
268+
dgm1 = np.array([[0, 0]])
269+
if dgm2.size == 0:
270+
dgm2 = np.array([[0, 0]])
271+
dgm1Rot = dgm1.dot(R)
272+
dgm2Rot = dgm2.dot(R)
273+
for [i, j, d] in matching:
274+
i = int(i)
275+
j = int(j)
276+
if i != -1 or j != -1: # At least one point is a non-diagonal point
277+
if i == -1:
278+
diagElem = np.array([dgm2Rot[j, 0], 0])
279+
diagElem = diagElem.dot(R.T)
280+
plt.plot([dgm2[j, 0], diagElem[0]], [dgm2[j, 1], diagElem[1]], "g")
281+
elif j == -1:
282+
diagElem = np.array([dgm1Rot[i, 0], 0])
283+
diagElem = diagElem.dot(R.T)
284+
ax.plot([dgm1[i, 0], diagElem[0]], [dgm1[i, 1], diagElem[1]], "g")
285+
else:
286+
ax.plot([dgm1[i, 0], dgm2[j, 0]], [dgm1[i, 1], dgm2[j, 1]], "g")
287+
288+
plot_diagrams([dgm1, dgm2], labels=labels, ax=ax)

persim/wasserstein.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def wasserstein(dgm1, dgm2, matching=False):
7070
if N == 0:
7171
T = np.array([[0, 0]])
7272
N = 1
73-
# Step 1: Compute CSM between S and dgm2, including points on diagonal
73+
# Compute CSM between S and dgm2, including points on diagonal
7474
DUL = metrics.pairwise.pairwise_distances(S, T)
7575

7676
# Put diagonal elements into the matrix
@@ -81,21 +81,29 @@ def wasserstein(dgm1, dgm2, matching=False):
8181
R = np.array([[cp, -sp], [sp, cp]])
8282
S = S[:, 0:2].dot(R)
8383
T = T[:, 0:2].dot(R)
84-
D = np.zeros((M+N, M+N))
84+
D = np.inf*np.ones((M+N, M+N))
85+
np.fill_diagonal(D, 0)
8586
D[0:M, 0:N] = DUL
86-
UR = np.max(D)*np.ones((M, M))
87+
UR = np.inf*np.ones((M, M))
8788
np.fill_diagonal(UR, S[:, 1])
8889
D[0:M, N:N+M] = UR
89-
UL = np.max(D)*np.ones((N, N))
90+
UL = np.inf*np.ones((N, N))
9091
np.fill_diagonal(UL, T[:, 1])
9192
D[M:N+M, 0:N] = UL
93+
print(D)
9294

9395
# Step 2: Run the hungarian algorithm
9496
matchi, matchj = optimize.linear_sum_assignment(D)
9597
matchdist = np.sum(D[matchi, matchj])
9698

9799
if matching:
98100
matchidx = [(i, j) for i, j in zip(matchi, matchj)]
99-
return matchdist, (matchidx, D)
101+
ret = np.zeros((len(matchidx), 3))
102+
ret[:, 0:2] = np.array(matchidx)
103+
ret[:, 2] = D[matchi, matchj]
104+
# Indicate diagonally matched points
105+
ret[ret[:, 0] >= M, 0] = -1
106+
ret[ret[:, 1] >= N, 1] = -1
107+
return matchdist, ret
100108

101109
return matchdist

test/test_distances.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,15 @@ def test_matching(self):
7575
[1.0, 1.1],
7676
])
7777

78-
d, (m, D) = bottleneck(
78+
d, m = bottleneck(
7979
dgm1, dgm2,
8080
matching=True
8181
)
82-
83-
# These are very loose bounds
84-
assert len(m) == len(dgm1) + len(dgm2)
85-
assert D.shape == (len(dgm1) + len(dgm2), len(dgm1) + len(dgm2))
82+
u1 = np.unique(m[:, 0])
83+
u1 = u1[u1 >= 0]
84+
u2 = np.unique(m[:, 1])
85+
u2 = u2[u2 >= 0]
86+
assert u1.size == dgm1.shape[0] and u2.size == dgm2.shape[0]
8687

8788
def test_matching_to_self(self):
8889
# Matching a diagram to itself should yield 0

test/test_visuals.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def test_bottleneck_matching(self):
198198
[0.3, 0.45]
199199
])
200200

201-
d, (matching, D) = persim.bottleneck(dgm1, dgm2, matching=True)
202-
persim.bottleneck_matching(dgm1, dgm2, matching, D)
201+
d, matching = persim.bottleneck(dgm1, dgm2, matching=True)
202+
persim.bottleneck_matching(dgm1, dgm2, matching)
203203

204204
def test_plot_labels(self):
205205
dgm1 = np.array([
@@ -211,6 +211,6 @@ def test_plot_labels(self):
211211
[0.3, 0.45]
212212
])
213213

214-
d, (matching, D) = persim.bottleneck(dgm1, dgm2, matching=True)
214+
d, matching = persim.bottleneck(dgm1, dgm2, matching=True)
215215
persim.bottleneck_matching(
216-
dgm1, dgm2, matching, D, labels=["X", "Y"])
216+
dgm1, dgm2, matching, labels=["X", "Y"])

0 commit comments

Comments
 (0)