1
+ """
2
+
3
+ Implementation of the Wasserstein distance using
4
+ the Hungarian algorithm
5
+
6
+ Author: Chris Tralie
7
+
8
+ """
1
9
import numpy as np
2
10
from sklearn import metrics
3
11
from scipy import optimize
12
+ import warnings
4
13
5
14
__all__ = ["wasserstein" ]
6
15
@@ -22,7 +31,7 @@ def wasserstein(dgm1, dgm2, matching=False):
22
31
dgm2: Nx(>=2)
23
32
array of birth/death paris for PD 2
24
33
matching: bool, default False
25
- if True, return matching infromation and cross-similarity matrix
34
+ if True, return matching information and cross-similarity matrix
26
35
27
36
Returns
28
37
---------
@@ -34,34 +43,52 @@ def wasserstein(dgm1, dgm2, matching=False):
34
43
35
44
"""
36
45
37
- # Step 1: Compute CSM between S and dgm2, including points on diagonal
38
- N = dgm1 .shape [0 ]
39
- M = dgm2 .shape [0 ]
40
- # Handle the cases where there are no points in the diagrams
41
- if N == 0 :
42
- dgm1 = np .array ([[0 , 0 ]])
43
- N = 1
46
+ S = np .array (dgm1 )
47
+ M = min (S .shape [0 ], S .size )
48
+ if S .size > 0 :
49
+ S = S [np .isfinite (S [:, 1 ]), :]
50
+ if S .shape [0 ] < M :
51
+ warnings .warn (
52
+ "dgm1 has points with non-finite death times;" +
53
+ "ignoring those points"
54
+ )
55
+ M = S .shape [0 ]
56
+ T = np .array (dgm2 )
57
+ N = min (T .shape [0 ], T .size )
58
+ if T .size > 0 :
59
+ T = T [np .isfinite (T [:, 1 ]), :]
60
+ if T .shape [0 ] < N :
61
+ warnings .warn (
62
+ "dgm2 has points with non-finite death times;" +
63
+ "ignoring those points"
64
+ )
65
+ N = T .shape [0 ]
66
+
44
67
if M == 0 :
45
- dgm2 = np .array ([[0 , 0 ]])
68
+ S = np .array ([[0 , 0 ]])
46
69
M = 1
47
- DUL = metrics .pairwise .pairwise_distances (dgm1 , dgm2 )
70
+ if N == 0 :
71
+ T = np .array ([[0 , 0 ]])
72
+ N = 1
73
+ # Step 1: Compute CSM between S and dgm2, including points on diagonal
74
+ DUL = metrics .pairwise .pairwise_distances (S , T )
48
75
49
76
# Put diagonal elements into the matrix
50
77
# Rotate the diagrams to make it easy to find the straight line
51
78
# distance to the diagonal
52
79
cp = np .cos (np .pi / 4 )
53
80
sp = np .sin (np .pi / 4 )
54
81
R = np .array ([[cp , - sp ], [sp , cp ]])
55
- dgm1 = dgm1 [:, 0 :2 ].dot (R )
56
- dgm2 = dgm2 [:, 0 :2 ].dot (R )
57
- D = np .zeros ((N + M , N + M ))
58
- D [0 :N , 0 :M ] = DUL
59
- UR = np .max (D )* np .ones ((N , N ))
60
- np .fill_diagonal (UR , dgm1 [:, 1 ])
61
- D [0 :N , M : M + N ] = UR
62
- UL = np .max (D )* np .ones ((M , M ))
63
- np .fill_diagonal (UL , dgm2 [:, 1 ])
64
- D [N : M + N , 0 :M ] = UL
82
+ S = S [:, 0 :2 ].dot (R )
83
+ T = T [:, 0 :2 ].dot (R )
84
+ D = np .zeros ((M + N , M + N ))
85
+ D [0 :M , 0 :N ] = DUL
86
+ UR = np .max (D )* np .ones ((M , M ))
87
+ np .fill_diagonal (UR , S [:, 1 ])
88
+ D [0 :M , N : N + M ] = UR
89
+ UL = np .max (D )* np .ones ((N , N ))
90
+ np .fill_diagonal (UL , T [:, 1 ])
91
+ D [M : N + M , 0 :N ] = UL
65
92
66
93
# Step 2: Run the hungarian algorithm
67
94
matchi , matchj = optimize .linear_sum_assignment (D )
0 commit comments