Skip to content

Commit a040a58

Browse files
authored
Merge pull request #52 from mmcdermott/master
Made `ax` arg for plot functions actually be used.
2 parents fe5af88 + 9f47c21 commit a040a58

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

persim/visuals.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
195195
persim.bottleneck_matching(A_h1, B_h1, matchidx, D)
196196
197197
"""
198+
ax = ax or plt.gca()
198199

199200
plot_diagrams([I1, I2], labels=labels, ax=ax)
200201
cp = np.cos(np.pi / 4)
@@ -217,9 +218,9 @@ def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
217218
elif j >= I2.shape[0]:
218219
diagElem = np.array([I1Rot[i, 0], 0])
219220
diagElem = diagElem.dot(R.T)
220-
plt.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
221+
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
221222
else:
222-
plt.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
223+
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
223224

224225

225226
def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"], colors=None, ax=None):
@@ -246,6 +247,7 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
246247
persim.wasserstein_matching(A_h1, B_h1, matchidx, D)
247248
248249
"""
250+
ax = ax or plt.gca()
249251

250252
cp = np.cos(np.pi / 4)
251253
sp = np.sin(np.pi / 4)
@@ -267,8 +269,8 @@ def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"]
267269
elif j >= I2.shape[0]:
268270
diagElem = np.array([I1Rot[i, 0], 0])
269271
diagElem = diagElem.dot(R.T)
270-
plt.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
272+
ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
271273
else:
272-
plt.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
274+
ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
273275

274276
plot_diagrams([I1, I2], labels=labels, ax=ax)

0 commit comments

Comments
 (0)