21
21
from tensorflow_probability .python .internal import prefer_static as ps
22
22
from tensorflow_probability .python .internal import tensorshape_util
23
23
24
- __all__ = ['iterative_mergesort' , ' kendalls_tau' ]
24
+ __all__ = ['kendalls_tau' ]
25
25
26
26
27
- def iterative_mergesort (y , permutation , name = None ):
28
- """Non-recusive mergesort that counts exchanges.
27
+ def _tril_indices (n ):
28
+ """Emulate np.tril_indices(n, k=-1).
29
+
30
+ This method ensures static shapes throughout (ie, XLA compilable).
31
+ This method only works for n <= 30000.
29
32
30
33
Args:
31
- y: a `Tensor` of shape `[n]` containing values to be sorted.
32
- permutation: `Tensor` of shape `[n]` with original ordering.
33
- name: Optional Python `str` name for ops created by this method.
34
- Default value: `None` (i.e., 'iterative_mergesort').
34
+ n: number elements to generate all pairs
35
35
36
36
Returns:
37
- exchanges: `int32` scalar that counts the number of exchanges required to
38
- produce a sorted permutation
39
- permutation: and a `tf.int32` Tensor that contains the ordering of y values
40
- that are sorted.
37
+ A [2, n * (n - 1) / 2] vector of all combinations of range(n).
41
38
"""
39
+ n = tf .convert_to_tensor (n , dtype_hint = tf .int32 )
40
+ # Number of lower triangular entries in an nxn matrix
41
+ m = (n - 1 ) * n / 2
42
+ r = tf .cast (tf .range (m ), dtype = tf .float64 )
42
43
43
- with tf .name_scope (name or 'iterative_mergesort' ):
44
- y = tf .convert_to_tensor (y , name = 'y' )
45
- permutation = tf .convert_to_tensor (
46
- permutation , name = 'permutation' , dtype = tf .int32 )
47
- shape = permutation .shape
48
- tensorshape_util .assert_is_compatible_with (y .shape , shape )
49
- n = ps .size (y )
50
-
51
- def outer_body (k , exchanges , permutation ):
52
- # The outer body progressively merges lists as k grows by powers of 2,
53
- # tracking the total swaps required in exchanges as the new permutation is
54
- # built in place.
55
- y_ordered = tf .gather (y , permutation )
56
-
57
- def middle_body (left , exchanges , permutation ):
58
- # the middle body advances through the sublists of size k, advancing
59
- # the left edge until the end of the input is reached.
60
- right = left + k
61
- end = tf .minimum (right + k , n )
62
-
63
- # See explanation here
64
- # https://www.geeksforgeeks.org/counting-inversions/.
65
-
66
- def inner_body (i , j , x , np , p ):
67
- # The [left, right) and [right, end) lists are merged sorted, with
68
- # i and j tracking the advance through each range. x records the
69
- # number of order (bubble-sort equivalent) swaps that are happening
70
- # with each insertion, and np represents the size of the output
71
- # permutation that's been filled in using the p tensor.
72
- y_less = y_ordered [i ] <= y_ordered [j ]
73
- element = tf .where (y_less , [permutation [i ]], [permutation [j ]])
74
- new_p = tf .concat ([p [0 :np ], element , p [np + 1 :n ]], axis = 0 )
75
- tensorshape_util .set_shape (new_p , p .shape )
76
- return (tf .where (y_less , i + 1 , i ), tf .where (y_less , j , j + 1 ),
77
- tf .where (y_less , x , x + right - i ), np + 1 , new_p )
78
-
79
- i_j_x_np_p = (left , right , exchanges , 0 , tf .zeros ([n ], dtype = tf .int32 ))
80
- (i , j , exchanges , np , p ) = tf .while_loop (
81
- cond = lambda i , j , x , np , p : tf .math .logical_and (i < right , j < end ),
82
- body = inner_body ,
83
- loop_vars = i_j_x_np_p )
84
- permutation = tf .concat ([
85
- permutation [0 :left ], p [0 :np ], permutation [i :right ],
86
- permutation [j :end ], permutation [end :n ]
87
- ],
88
- axis = 0 )
89
- tensorshape_util .set_shape (permutation , shape )
90
- return left + 2 * k , exchanges , permutation
91
-
92
- _ , exchanges , permutation = tf .while_loop (
93
- cond = lambda left , exchanges , permutation : left < n - k ,
94
- body = middle_body ,
95
- loop_vars = (0 , exchanges , permutation ))
96
- k *= 2
97
- return k , exchanges , permutation
98
-
99
- _ , exchanges , permutation = tf .while_loop (
100
- cond = lambda k , exchanges , permutation : k < n ,
101
- body = outer_body ,
102
- loop_vars = (1 , 0 , permutation ))
103
- return exchanges , permutation
104
-
105
-
106
- def lexicographical_indirect_sort (primary , secondary , name = None ):
107
- """Sorts by primary, then by secondary returning the indices.
44
+ # From Sloane: https://oeis.org/A002024 "k appears k times"
45
+ # e.g., [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, ...]
46
+ e = tf .math .floor (tf .math .sqrt (2 * (r + 1 )) + .5 )
108
47
109
- Args:
110
- primary: a `Tensor` of shape `[n]` containing the primary sort key. the
111
- primary sort key value.
112
- secondary: a `Tensor` of shape `[n]` containing the secondary sort key to be
113
- used when the primary keys are identical.
114
- name: Optional Python `str` name for ops created by this method.
115
- Default value: `None` (i.e., 'lexicographical_indirect_sort').
48
+ # From Sloane: https://oeis.org/A002262 "Triangle read by rows"
49
+ # e.g., [0, 0, 1, 0, 1, 2, 0, 1, 2, 3, ...]
50
+ f = tf .math .floor (tf .math .sqrt (2 * r + .25 ) - .5 )
51
+ g = r - f * (f + 1 ) / 2
116
52
117
- Returns:
118
- lexicographic: A permutation of range(n) that provides the sorted primary,
119
- then secondary values.
120
- """
121
- with tf .name_scope (name or 'lexicographical_indirect_sort' ):
122
- n = ps .size0 (primary )
123
- permutation = tf .argsort (primary )
124
- # scan for ties, and for each range of ties do a argsort on
125
- # the secondary value. (TF has no lexicographical sorting, although
126
- # jax can sort complex number lexicographically. Hmm.)
127
- primary_ordered = tf .gather (primary , permutation )
128
-
129
- def body (left , right , lexicographic ):
130
- # We make a single pass through the list using right and left, where right
131
- # advances and left chases it looking for spans that are equal in their
132
- # primary key to then institute a sort on the secondary key.
133
- not_equal = tf .not_equal (primary_ordered [left ], primary_ordered [right ])
134
-
135
- def secondary_sort ():
136
- x = tf .concat ([
137
- lexicographic [0 :left ],
138
- tf .gather (permutation [left :right ],
139
- tf .argsort (tf .gather (secondary ,
140
- permutation [left :right ]))),
141
- lexicographic [right :n ],
142
- ],
143
- axis = 0 )
144
- tensorshape_util .set_shape (x , [n ])
145
- return x
146
-
147
- return (tf .where (not_equal , right , left ), right + 1 ,
148
- tf .cond (not_equal , secondary_sort , lambda : lexicographic ))
149
-
150
- left , _ , lexicographic = tf .while_loop (
151
- cond = lambda left , right , lexicographic : right < n ,
152
- body = body ,
153
- loop_vars = (0 , 0 , tf .zeros_like (permutation , dtype = tf .int32 )))
154
- return tf .concat ([
155
- lexicographic [0 :left ],
156
- tf .gather (permutation [left :n ],
157
- tf .argsort (tf .gather (secondary , permutation [left :n ])))
158
- ],
159
- axis = 0 )
53
+ return tf .cast (tf .stack ([e , g ]), dtype = tf .int32 )
160
54
161
55
162
56
def kendalls_tau (y_true , y_pred , name = None ):
163
57
"""Computes Kendall's Tau for two ordered lists.
164
58
165
- Kendall's Tau measures the correlation between ordinal rankings. This
166
- implementation is similar to the one used in scipy.stats.kendalltau.
59
+ Kendall's Tau measures the correlation between ordinal rankings.
167
60
The provided values may be of any type that is sortable, with the
168
61
argsort indices indicating the true or proposed ordinal sequence.
169
62
@@ -189,62 +82,21 @@ def kendalls_tau(y_true, y_pred, name=None):
189
82
ps .size (y_true ), 1 , 'Ordering requires at least 2 elements.' )
190
83
]
191
84
with tf .control_dependencies (assertions ):
192
- lexa = lexicographical_indirect_sort (y_true , y_pred )
193
-
194
- # See A Computer Method for Calculating Kendall's Tau with Ungrouped Data
195
- # by William Night, Journal of the American Statistical Association,
196
- # Jun., 1966, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439
197
- # for notation https://www.jstor.org/stable/2282833
198
-
199
- def jointly_tied_pairs_body (first , t , i ):
200
- not_equal = tf .math .logical_or (
201
- tf .not_equal (y_true [lexa [first ]], y_true [lexa [i ]]),
202
- tf .not_equal (y_pred [lexa [first ]], y_pred [lexa [i ]]))
203
- return (tf .where (not_equal , i , first ),
204
- tf .where (not_equal , t + ((i - first ) * (i - first - 1 )) // 2 ,
205
- t ), i + 1 )
206
-
207
- n = ps .size0 (y_true )
208
- first , t , _ = tf .while_loop (
209
- cond = lambda first , t , i : i < n ,
210
- body = jointly_tied_pairs_body ,
211
- loop_vars = (0 , 0 , 1 ))
212
- t += ((n - first ) * (n - first - 1 )) // 2
213
-
214
- def ties_y_true_body (first , v , i ):
215
- not_equal = tf .not_equal (y_true [lexa [first ]], y_true [lexa [i ]])
216
- return (tf .where (not_equal , i , first ),
217
- tf .where (not_equal , v + ((i - first ) * (i - first - 1 )) // 2 ,
218
- v ), i + 1 )
219
-
220
- first , v , _ = tf .while_loop (
221
- cond = lambda first , v , i : i < n ,
222
- body = ties_y_true_body ,
223
- loop_vars = (0 , 0 , 1 ))
224
- v += ((n - first ) * (n - first - 1 )) // 2
225
-
226
- # count exchanges
227
- exchanges , newperm = iterative_mergesort (y_pred , lexa )
228
-
229
- def ties_in_y_pred_body (first , u , i ):
230
- not_equal = tf .not_equal (y_pred [newperm [first ]], y_pred [newperm [i ]])
231
- return (tf .where (not_equal , i , first ),
232
- tf .where (not_equal , u + ((i - first ) * (i - first - 1 )) // 2 ,
233
- u ), i + 1 )
234
-
235
- first , u , _ = tf .while_loop (
236
- cond = lambda first , u , i : i < n ,
237
- body = ties_in_y_pred_body ,
238
- loop_vars = (0 , 0 , 1 ))
239
- u += ((n - first ) * (n - first - 1 )) // 2
240
- n0 = (n * (n - 1 )) // 2
241
- assertions = [
242
- assert_util .assert_less (v , tf .cast (n0 , tf .int32 ),
243
- 'All ranks are ties for y_true.' ),
244
- assert_util .assert_less (u , tf .cast (n0 , tf .int32 ),
245
- 'All ranks are ties for y_pred.' )
246
- ]
247
- with tf .control_dependencies (assertions ):
248
- return (tf .cast (n0 - (u + v - t ), tf .float32 ) -
249
- 2.0 * tf .cast (exchanges , tf .float32 )) / tf .math .sqrt (
250
- tf .cast (n0 - v , tf .float32 ) * tf .cast (n0 - u , tf .float32 ))
85
+ n = ps .size0 (y_true )
86
+ indices = _tril_indices (n )
87
+ dxij = tf .sign (
88
+ tf .gather (y_true , indices [0 ]) - tf .gather (y_true , indices [1 ]))
89
+ dyij = tf .sign (
90
+ tf .gather (y_pred , indices [0 ]) - tf .gather (y_pred , indices [1 ]))
91
+ # s is sum of concordant pairs minus discordant pairs.
92
+ s = tf .cast (tf .math .reduce_sum (dxij * dyij ), tf .float32 )
93
+ # t is the number of y_true pairs that are not ties.
94
+ t = tf .math .count_nonzero (dxij , dtype = tf .float32 )
95
+ # u is the number of y_pred pairs that are not ties.
96
+ u = tf .math .count_nonzero (dyij , dtype = tf .float32 )
97
+ assertions = [
98
+ assert_util .assert_positive (t , 'All ranks are ties for y_true.' ),
99
+ assert_util .assert_positive (u , 'All ranks are ties for y_pred.' )
100
+ ]
101
+ with tf .control_dependencies (assertions ):
102
+ return s / tf .math .sqrt (t * u )
0 commit comments