@@ -168,20 +168,21 @@ def plot_diagrams(
168
168
def plot_a_bar (p , q , c = 'b' , linestyle = '-' ):
169
169
plt .plot ([p [0 ], q [0 ]], [p [1 ], q [1 ]], c = c , linestyle = linestyle , linewidth = 1 )
170
170
171
- def bottleneck_matching (I1 , I2 , matchidx , D , labels = ["dgm1" , "dgm2" ], ax = None ):
171
+ def bottleneck_matching (dgm1 , dgm2 , matching , labels = ["dgm1" , "dgm2" ], ax = None ):
172
172
""" Visualize bottleneck matching between two diagrams
173
173
174
174
Parameters
175
175
===========
176
176
177
- I1: array
178
- A diagram
179
- I2: array
180
- A diagram
181
- matchidx: tuples of matched indices
182
- if input `matching=True`, then return matching
183
- D: array
184
- cross-similarity matrix
177
+ dgm1: Mx(>=2)
178
+ array of birth/death pairs for PD 1
179
+ dgm2: Nx(>=2)
180
+ array of birth/death paris for PD 2
181
+ matching: ndarray(Mx+Nx, 3)
182
+ A list of correspondences in an optimal matching, as well as their distance, where:
183
+ * First column is index of point in first persistence diagram, or -1 if diagonal
184
+ * Second column is index of point in second persistence diagram, or -1 if diagonal
185
+ * Third column is the distance of each matching
185
186
labels: list of strings
186
187
names of diagrams for legend. Default = ["dgm1", "dgm2"],
187
188
ax: matplotlib Axis object
@@ -191,50 +192,61 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
191
192
Examples
192
193
==========
193
194
194
- bn_matching, (matchidx, D) = persim.bottleneck(A_h1, B_h1, matching=True)
195
- persim.bottleneck_matching(A_h1, B_h1, matchidx, D )
195
+ dist, matching = persim.bottleneck(A_h1, B_h1, matching=True)
196
+ persim.bottleneck_matching(A_h1, B_h1, matching )
196
197
197
198
"""
198
199
ax = ax or plt .gca ()
199
200
200
- plot_diagrams ([I1 , I2 ], labels = labels , ax = ax )
201
+ plot_diagrams ([dgm1 , dgm2 ], labels = labels , ax = ax )
201
202
cp = np .cos (np .pi / 4 )
202
203
sp = np .sin (np .pi / 4 )
203
204
R = np .array ([[cp , - sp ], [sp , cp ]])
204
- if I1 .size == 0 :
205
- I1 = np .array ([[0 , 0 ]])
206
- if I2 .size == 0 :
207
- I2 = np .array ([[0 , 0 ]])
208
- I1Rot = I1 .dot (R )
209
- I2Rot = I2 .dot (R )
210
- dists = [D [i , j ] for (i , j ) in matchidx ]
211
- (i , j ) = matchidx [np .argmax (dists )]
212
- if i >= I1 .shape [0 ] and j >= I2 .shape [0 ]:
213
- return
214
- if i >= I1 .shape [0 ]:
215
- diagElem = np .array ([I2Rot [j , 0 ], 0 ])
216
- diagElem = diagElem .dot (R .T )
217
- plt .plot ([I2 [j , 0 ], diagElem [0 ]], [I2 [j , 1 ], diagElem [1 ]], "g" )
218
- elif j >= I2 .shape [0 ]:
219
- diagElem = np .array ([I1Rot [i , 0 ], 0 ])
220
- diagElem = diagElem .dot (R .T )
221
- ax .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
222
- else :
223
- ax .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
224
-
225
-
226
- def wasserstein_matching (I1 , I2 , matchidx , palette = None , labels = ["dgm1" , "dgm2" ], colors = None , ax = None ):
205
+ if dgm1 .size == 0 :
206
+ dgm1 = np .array ([[0 , 0 ]])
207
+ if dgm2 .size == 0 :
208
+ dgm2 = np .array ([[0 , 0 ]])
209
+ dgm1Rot = dgm1 .dot (R )
210
+ dgm2Rot = dgm2 .dot (R )
211
+ max_idx = np .argmax (matching [:, 2 ])
212
+ for idx , [i , j , d ] in enumerate (matching ):
213
+ i = int (i )
214
+ j = int (j )
215
+ linestyle = '--'
216
+ linewidth = 1
217
+ c = 'C2'
218
+ if idx == max_idx :
219
+ linestyle = '-'
220
+ linewidth = 2
221
+ c = 'C3'
222
+ if i != - 1 or j != - 1 : # At least one point is a non-diagonal point
223
+ if i == - 1 :
224
+ diagElem = np .array ([dgm2Rot [j , 0 ], 0 ])
225
+ diagElem = diagElem .dot (R .T )
226
+ plt .plot ([dgm2 [j , 0 ], diagElem [0 ]], [dgm2 [j , 1 ], diagElem [1 ]], c , linewidth = linewidth , linestyle = linestyle )
227
+ elif j == - 1 :
228
+ diagElem = np .array ([dgm1Rot [i , 0 ], 0 ])
229
+ diagElem = diagElem .dot (R .T )
230
+ ax .plot ([dgm1 [i , 0 ], diagElem [0 ]], [dgm1 [i , 1 ], diagElem [1 ]], c , linewidth = linewidth , linestyle = linestyle )
231
+ else :
232
+ ax .plot ([dgm1 [i , 0 ], dgm2 [j , 0 ]], [dgm1 [i , 1 ], dgm2 [j , 1 ]], c , linewidth = linewidth , linestyle = linestyle )
233
+
234
+
235
+ def wasserstein_matching (dgm1 , dgm2 , matching , labels = ["dgm1" , "dgm2" ], ax = None ):
227
236
""" Visualize bottleneck matching between two diagrams
228
237
229
238
Parameters
230
239
===========
231
240
232
- I1 : array
241
+ dgm1 : array
233
242
A diagram
234
- I2 : array
243
+ dgm2 : array
235
244
A diagram
236
- matchidx: tuples of matched indices
237
- if input `matching=True`, then return matching
245
+ matching: ndarray(Mx+Nx, 3)
246
+ A list of correspondences in an optimal matching, as well as their distance, where:
247
+ * First column is index of point in first persistence diagram, or -1 if diagonal
248
+ * Second column is index of point in second persistence diagram, or -1 if diagonal
249
+ * Third column is the distance of each matching
238
250
labels: list of strings
239
251
names of diagrams for legend. Default = ["dgm1", "dgm2"],
240
252
ax: matplotlib Axis object
@@ -252,25 +264,25 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
252
264
cp = np .cos (np .pi / 4 )
253
265
sp = np .sin (np .pi / 4 )
254
266
R = np .array ([[cp , - sp ], [sp , cp ]])
255
- if I1 .size == 0 :
256
- I1 = np .array ([[0 , 0 ]])
257
- if I2 .size == 0 :
258
- I2 = np .array ([[0 , 0 ]])
259
- I1Rot = I1 .dot (R )
260
- I2Rot = I2 .dot (R )
261
- for index in matchidx :
262
- ( i , j ) = index
263
- if i >= I1 . shape [ 0 ] and j >= I2 . shape [ 0 ]:
264
- continue
265
- if i >= I1 . shape [ 0 ] :
266
- diagElem = np .array ([I2Rot [j , 0 ], 0 ])
267
- diagElem = diagElem .dot (R .T )
268
- plt .plot ([I2 [j , 0 ], diagElem [0 ]], [I2 [j , 1 ], diagElem [1 ]], "g" )
269
- elif j >= I2 . shape [ 0 ] :
270
- diagElem = np .array ([I1Rot [i , 0 ], 0 ])
271
- diagElem = diagElem .dot (R .T )
272
- ax .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
273
- else :
274
- ax .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
275
-
276
- plot_diagrams ([I1 , I2 ], labels = labels , ax = ax )
267
+ if dgm1 .size == 0 :
268
+ dgm1 = np .array ([[0 , 0 ]])
269
+ if dgm2 .size == 0 :
270
+ dgm2 = np .array ([[0 , 0 ]])
271
+ dgm1Rot = dgm1 .dot (R )
272
+ dgm2Rot = dgm2 .dot (R )
273
+ for [ i , j , d ] in matching :
274
+ i = int ( i )
275
+ j = int ( j )
276
+ if i != - 1 or j != - 1 : # At least one point is a non-diagonal point
277
+ if i == - 1 :
278
+ diagElem = np .array ([dgm2Rot [j , 0 ], 0 ])
279
+ diagElem = diagElem .dot (R .T )
280
+ plt .plot ([dgm2 [j , 0 ], diagElem [0 ]], [dgm2 [j , 1 ], diagElem [1 ]], "g" )
281
+ elif j == - 1 :
282
+ diagElem = np .array ([dgm1Rot [i , 0 ], 0 ])
283
+ diagElem = diagElem .dot (R .T )
284
+ ax .plot ([dgm1 [i , 0 ], diagElem [0 ]], [dgm1 [i , 1 ], diagElem [1 ]], "g" )
285
+ else :
286
+ ax .plot ([dgm1 [i , 0 ], dgm2 [j , 0 ]], [dgm1 [i , 1 ], dgm2 [j , 1 ]], "g" )
287
+
288
+ plot_diagrams ([dgm1 , dgm2 ], labels = labels , ax = ax )
0 commit comments