Skip to content

Commit bf64346

Browse files
committed
Wasserstein is back on track now
1 parent cc5c351 commit bf64346

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

RELEASE.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
0.3.1
2+
- Fixed bug with repeated intervals in bottleneck
3+
- Tidied up API for indicating matchings for bottleneck and wasserstein, and updated notebook
4+
15
0.3.0
26
- Add implementations of Persistence Landscapes, including plotting methods, a transformer, and additional notebooks.
37

persim/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.0"
1+
__version__ = "0.3.1"

persim/wasserstein.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def wasserstein(dgm1, dgm2, matching=False):
8181
R = np.array([[cp, -sp], [sp, cp]])
8282
S = S[:, 0:2].dot(R)
8383
T = T[:, 0:2].dot(R)
84-
D = np.inf*np.ones((M+N, M+N))
84+
D = np.zeros((M+N, M+N))
8585
np.fill_diagonal(D, 0)
8686
D[0:M, 0:N] = DUL
8787
UR = np.inf*np.ones((M, M))
@@ -90,7 +90,6 @@ def wasserstein(dgm1, dgm2, matching=False):
9090
UL = np.inf*np.ones((N, N))
9191
np.fill_diagonal(UL, T[:, 1])
9292
D[M:N+M, 0:N] = UL
93-
print(D)
9493

9594
# Step 2: Run the hungarian algorithm
9695
matchi, matchj = optimize.linear_sum_assignment(D)
@@ -104,6 +103,8 @@ def wasserstein(dgm1, dgm2, matching=False):
104103
# Indicate diagonally matched points
105104
ret[ret[:, 0] >= M, 0] = -1
106105
ret[ret[:, 1] >= N, 1] = -1
106+
# Exclude diagonal to diagonal
107+
ret = ret[ret[:, 0] + ret[:, 1] != -2, :]
107108
return matchdist, ret
108109

109110
return matchdist

test/test_distances.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,32 @@ def test_inf_deathtime(self):
183183
assert (np.allclose(dist1, np.sqrt(2)/2)) and (np.allclose(dist2, np.sqrt(2)/2))
184184

185185
def test_repeated(self):
186-
dgm1 = np.array([[0, 10], [0,10], [0,11]])
186+
dgm1 = np.array([[0, 10], [0,10]])
187187
dgm2 = np.array([[0, 10]])
188188
dist = wasserstein(dgm1, dgm2)
189-
assert dist == 10.5
189+
assert dist == 5*np.sqrt(2)
190+
191+
def test_matching(self):
192+
dgm1 = np.array([
193+
[0.5, 1],
194+
[0.6, 1.1]
195+
])
196+
dgm2 = np.array([
197+
[0.5, 1.1],
198+
[0.6, 1.1],
199+
[0.8, 1.1],
200+
[1.0, 1.1],
201+
])
202+
203+
d, m = wasserstein(
204+
dgm1, dgm2,
205+
matching=True
206+
)
207+
u1 = np.unique(m[:, 0])
208+
u1 = u1[u1 >= 0]
209+
u2 = np.unique(m[:, 1])
210+
u2 = u2[u2 >= 0]
211+
assert u1.size == dgm1.shape[0] and u2.size == dgm2.shape[0]
190212

191213

192214
class TestSliced:

0 commit comments

Comments
 (0)