@@ -195,6 +195,7 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
195
195
persim.bottleneck_matching(A_h1, B_h1, matchidx, D)
196
196
197
197
"""
198
+ ax = ax or plt .gca ()
198
199
199
200
plot_diagrams ([I1 , I2 ], labels = labels , ax = ax )
200
201
cp = np .cos (np .pi / 4 )
@@ -217,9 +218,9 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
217
218
elif j >= I2 .shape [0 ]:
218
219
diagElem = np .array ([I1Rot [i , 0 ], 0 ])
219
220
diagElem = diagElem .dot (R .T )
220
- plt .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
221
+ ax .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
221
222
else :
222
- plt .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
223
+ ax .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
223
224
224
225
225
226
def wasserstein_matching (I1 , I2 , matchidx , palette = None , labels = ["dgm1" , "dgm2" ], colors = None , ax = None ):
@@ -246,6 +247,7 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
246
247
persim.wasserstein_matching(A_h1, B_h1, matchidx, D)
247
248
248
249
"""
250
+ ax = ax or plt .gca ()
249
251
250
252
cp = np .cos (np .pi / 4 )
251
253
sp = np .sin (np .pi / 4 )
@@ -267,8 +269,8 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
267
269
elif j >= I2 .shape [0 ]:
268
270
diagElem = np .array ([I1Rot [i , 0 ], 0 ])
269
271
diagElem = diagElem .dot (R .T )
270
- plt .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
272
+ ax .plot ([I1 [i , 0 ], diagElem [0 ]], [I1 [i , 1 ], diagElem [1 ]], "g" )
271
273
else :
272
- plt .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
274
+ ax .plot ([I1 [i , 0 ], I2 [j , 0 ]], [I1 [i , 1 ], I2 [j , 1 ]], "g" )
273
275
274
276
plot_diagrams ([I1 , I2 ], labels = labels , ax = ax )
0 commit comments