Skip to content

Commit 4533188

Browse files
Googler authored and tensorflower-gardener committed
Changing implementation to one that works in all modes.
PiperOrigin-RevId: 463143711
1 parent 943de2c commit 4533188

File tree

4 files changed

+39
-211
lines changed

4 files changed

+39
-211
lines changed

tensorflow_probability/python/stats/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@ multi_substrate_py_library(
3333
name = "stats",
3434
srcs = ["__init__.py"],
3535
jax_omit_deps = [
36-
":kendalls_tau",
3736
":ranking",
3837
],
3938
deps = [

tensorflow_probability/python/stats/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@
2020
from tensorflow_probability.python.stats.calibration import brier_score
2121
from tensorflow_probability.python.stats.calibration import expected_calibration_error
2222
from tensorflow_probability.python.stats.calibration import expected_calibration_error_quantiles
23-
from tensorflow_probability.python.stats.kendalls_tau import iterative_mergesort
2423
from tensorflow_probability.python.stats.kendalls_tau import kendalls_tau
25-
from tensorflow_probability.python.stats.kendalls_tau import lexicographical_indirect_sort
2624
from tensorflow_probability.python.stats.leave_one_out import log_loomean_exp
2725
from tensorflow_probability.python.stats.leave_one_out import log_loosum_exp
2826
from tensorflow_probability.python.stats.leave_one_out import log_soomean_exp
@@ -65,9 +63,7 @@
6563
'expected_calibration_error_quantiles',
6664
'find_bins',
6765
'histogram',
68-
'iterative_mergesort',
6966
'kendalls_tau',
70-
'lexicographical_indirect_sort',
7167
'log_average_probs',
7268
'log_loomean_exp',
7369
'log_loosum_exp',

tensorflow_probability/python/stats/kendalls_tau.py

Lines changed: 39 additions & 187 deletions
Original file line number | Diff line number | Diff line change
@@ -21,149 +21,42 @@
2121
from tensorflow_probability.python.internal import prefer_static as ps
2222
from tensorflow_probability.python.internal import tensorshape_util
2323

24-
__all__ = ['iterative_mergesort', 'kendalls_tau']
24+
__all__ = ['kendalls_tau']
2525

2626

27-
def iterative_mergesort(y, permutation, name=None):
28-
"""Non-recusive mergesort that counts exchanges.
27+
def _tril_indices(n):
28+
"""Emulate np.tril_indices(n, k=-1).
29+
30+
This method ensures static shapes throughout (ie, XLA compilable).
31+
This method only works for n <= 30000.
2932
3033
Args:
31-
y: a `Tensor` of shape `[n]` containing values to be sorted.
32-
permutation: `Tensor` of shape `[n]` with original ordering.
33-
name: Optional Python `str` name for ops created by this method.
34-
Default value: `None` (i.e., 'iterative_mergesort').
34+
n: number elements to generate all pairs
3535
3636
Returns:
37-
exchanges: `int32` scalar that counts the number of exchanges required to
38-
produce a sorted permutation
39-
permutation: and a `tf.int32` Tensor that contains the ordering of y values
40-
that are sorted.
37+
A [2, n * (n - 1) / 2] vector of all combinations of range(n).
4138
"""
39+
n = tf.convert_to_tensor(n, dtype_hint=tf.int32)
40+
# Number of lower triangular entries in an nxn matrix
41+
m = (n - 1) * n / 2
42+
r = tf.cast(tf.range(m), dtype=tf.float64)
4243

43-
with tf.name_scope(name or 'iterative_mergesort'):
44-
y = tf.convert_to_tensor(y, name='y')
45-
permutation = tf.convert_to_tensor(
46-
permutation, name='permutation', dtype=tf.int32)
47-
shape = permutation.shape
48-
tensorshape_util.assert_is_compatible_with(y.shape, shape)
49-
n = ps.size(y)
50-
51-
def outer_body(k, exchanges, permutation):
52-
# The outer body progressively merges lists as k grows by powers of 2,
53-
# tracking the total swaps required in exchanges as the new permutation is
54-
# built in place.
55-
y_ordered = tf.gather(y, permutation)
56-
57-
def middle_body(left, exchanges, permutation):
58-
# the middle body advances through the sublists of size k, advancing
59-
# the left edge until the end of the input is reached.
60-
right = left + k
61-
end = tf.minimum(right + k, n)
62-
63-
# See explanation here
64-
# https://www.geeksforgeeks.org/counting-inversions/.
65-
66-
def inner_body(i, j, x, np, p):
67-
# The [left, right) and [right, end) lists are merged sorted, with
68-
# i and j tracking the advance through each range. x records the
69-
# number of order (bubble-sort equivalent) swaps that are happening
70-
# with each insertion, and np represents the size of the output
71-
# permutation that's been filled in using the p tensor.
72-
y_less = y_ordered[i] <= y_ordered[j]
73-
element = tf.where(y_less, [permutation[i]], [permutation[j]])
74-
new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0)
75-
tensorshape_util.set_shape(new_p, p.shape)
76-
return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1),
77-
tf.where(y_less, x, x + right - i), np + 1, new_p)
78-
79-
i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32))
80-
(i, j, exchanges, np, p) = tf.while_loop(
81-
cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end),
82-
body=inner_body,
83-
loop_vars=i_j_x_np_p)
84-
permutation = tf.concat([
85-
permutation[0:left], p[0:np], permutation[i:right],
86-
permutation[j:end], permutation[end:n]
87-
],
88-
axis=0)
89-
tensorshape_util.set_shape(permutation, shape)
90-
return left + 2 * k, exchanges, permutation
91-
92-
_, exchanges, permutation = tf.while_loop(
93-
cond=lambda left, exchanges, permutation: left < n - k,
94-
body=middle_body,
95-
loop_vars=(0, exchanges, permutation))
96-
k *= 2
97-
return k, exchanges, permutation
98-
99-
_, exchanges, permutation = tf.while_loop(
100-
cond=lambda k, exchanges, permutation: k < n,
101-
body=outer_body,
102-
loop_vars=(1, 0, permutation))
103-
return exchanges, permutation
104-
105-
106-
def lexicographical_indirect_sort(primary, secondary, name=None):
107-
"""Sorts by primary, then by secondary returning the indices.
44+
# From Sloane: https://oeis.org/A002024 "k appears k times"
45+
# e.g., [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, ...]
46+
e = tf.math.floor(tf.math.sqrt(2 * (r + 1)) + .5)
10847

109-
Args:
110-
primary: a `Tensor` of shape `[n]` containing the primary sort key. the
111-
primary sort key value.
112-
secondary: a `Tensor` of shape `[n]` containing the secondary sort key to be
113-
used when the primary keys are identical.
114-
name: Optional Python `str` name for ops created by this method.
115-
Default value: `None` (i.e., 'lexicographical_indirect_sort').
48+
# From Sloane: https://oeis.org/A002262 "Triangle read by rows"
49+
# e.g., [0, 0, 1, 0, 1, 2, 0, 1, 2, 3, ...]
50+
f = tf.math.floor(tf.math.sqrt(2 * r + .25) - .5)
51+
g = r - f * (f + 1) / 2
11652

117-
Returns:
118-
lexicographic: A permutation of range(n) that provides the sorted primary,
119-
then secondary values.
120-
"""
121-
with tf.name_scope(name or 'lexicographical_indirect_sort'):
122-
n = ps.size0(primary)
123-
permutation = tf.argsort(primary)
124-
# scan for ties, and for each range of ties do a argsort on
125-
# the secondary value. (TF has no lexicographical sorting, although
126-
# jax can sort complex number lexicographically. Hmm.)
127-
primary_ordered = tf.gather(primary, permutation)
128-
129-
def body(left, right, lexicographic):
130-
# We make a single pass through the list using right and left, where right
131-
# advances and left chases it looking for spans that are equal in their
132-
# primary key to then institute a sort on the secondary key.
133-
not_equal = tf.not_equal(primary_ordered[left], primary_ordered[right])
134-
135-
def secondary_sort():
136-
x = tf.concat([
137-
lexicographic[0:left],
138-
tf.gather(permutation[left:right],
139-
tf.argsort(tf.gather(secondary,
140-
permutation[left:right]))),
141-
lexicographic[right:n],
142-
],
143-
axis=0)
144-
tensorshape_util.set_shape(x, [n])
145-
return x
146-
147-
return (tf.where(not_equal, right, left), right + 1,
148-
tf.cond(not_equal, secondary_sort, lambda: lexicographic))
149-
150-
left, _, lexicographic = tf.while_loop(
151-
cond=lambda left, right, lexicographic: right < n,
152-
body=body,
153-
loop_vars=(0, 0, tf.zeros_like(permutation, dtype=tf.int32)))
154-
return tf.concat([
155-
lexicographic[0:left],
156-
tf.gather(permutation[left:n],
157-
tf.argsort(tf.gather(secondary, permutation[left:n])))
158-
],
159-
axis=0)
53+
return tf.cast(tf.stack([e, g]), dtype=tf.int32)
16054

16155

16256
def kendalls_tau(y_true, y_pred, name=None):
16357
"""Computes Kendall's Tau for two ordered lists.
16458
165-
Kendall's Tau measures the correlation between ordinal rankings. This
166-
implementation is similar to the one used in scipy.stats.kendalltau.
59+
Kendall's Tau measures the correlation between ordinal rankings.
16760
The provided values may be of any type that is sortable, with the
16861
argsort indices indicating the true or proposed ordinal sequence.
16962
@@ -189,62 +82,21 @@ def kendalls_tau(y_true, y_pred, name=None):
18982
ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
19083
]
19184
with tf.control_dependencies(assertions):
192-
lexa = lexicographical_indirect_sort(y_true, y_pred)
193-
194-
# See A Computer Method for Calculating Kendall's Tau with Ungrouped Data
195-
# by William Night, Journal of the American Statistical Association,
196-
# Jun., 1966, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439
197-
# for notation https://www.jstor.org/stable/2282833
198-
199-
def jointly_tied_pairs_body(first, t, i):
200-
not_equal = tf.math.logical_or(
201-
tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]),
202-
tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]]))
203-
return (tf.where(not_equal, i, first),
204-
tf.where(not_equal, t + ((i - first) * (i - first - 1)) // 2,
205-
t), i + 1)
206-
207-
n = ps.size0(y_true)
208-
first, t, _ = tf.while_loop(
209-
cond=lambda first, t, i: i < n,
210-
body=jointly_tied_pairs_body,
211-
loop_vars=(0, 0, 1))
212-
t += ((n - first) * (n - first - 1)) // 2
213-
214-
def ties_y_true_body(first, v, i):
215-
not_equal = tf.not_equal(y_true[lexa[first]], y_true[lexa[i]])
216-
return (tf.where(not_equal, i, first),
217-
tf.where(not_equal, v + ((i - first) * (i - first - 1)) // 2,
218-
v), i + 1)
219-
220-
first, v, _ = tf.while_loop(
221-
cond=lambda first, v, i: i < n,
222-
body=ties_y_true_body,
223-
loop_vars=(0, 0, 1))
224-
v += ((n - first) * (n - first - 1)) // 2
225-
226-
# count exchanges
227-
exchanges, newperm = iterative_mergesort(y_pred, lexa)
228-
229-
def ties_in_y_pred_body(first, u, i):
230-
not_equal = tf.not_equal(y_pred[newperm[first]], y_pred[newperm[i]])
231-
return (tf.where(not_equal, i, first),
232-
tf.where(not_equal, u + ((i - first) * (i - first - 1)) // 2,
233-
u), i + 1)
234-
235-
first, u, _ = tf.while_loop(
236-
cond=lambda first, u, i: i < n,
237-
body=ties_in_y_pred_body,
238-
loop_vars=(0, 0, 1))
239-
u += ((n - first) * (n - first - 1)) // 2
240-
n0 = (n * (n - 1)) // 2
241-
assertions = [
242-
assert_util.assert_less(v, tf.cast(n0, tf.int32),
243-
'All ranks are ties for y_true.'),
244-
assert_util.assert_less(u, tf.cast(n0, tf.int32),
245-
'All ranks are ties for y_pred.')
246-
]
247-
with tf.control_dependencies(assertions):
248-
return (tf.cast(n0 - (u + v - t), tf.float32) -
249-
2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
250-
tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
85+
n = ps.size0(y_true)
86+
indices = _tril_indices(n)
87+
dxij = tf.sign(
88+
tf.gather(y_true, indices[0]) - tf.gather(y_true, indices[1]))
89+
dyij = tf.sign(
90+
tf.gather(y_pred, indices[0]) - tf.gather(y_pred, indices[1]))
91+
# s is sum of concordant pairs minus discordant pairs.
92+
s = tf.cast(tf.math.reduce_sum(dxij * dyij), tf.float32)
93+
# t is the number of y_true pairs that are not ties.
94+
t = tf.math.count_nonzero(dxij, dtype=tf.float32)
95+
# u is the number of y_pred pairs that are not ties.
96+
u = tf.math.count_nonzero(dyij, dtype=tf.float32)
97+
assertions = [
98+
assert_util.assert_positive(t, 'All ranks are ties for y_true.'),
99+
assert_util.assert_positive(u, 'All ranks are ties for y_pred.')
100+
]
101+
with tf.control_dependencies(assertions):
102+
return s / tf.math.sqrt(t * u)

tensorflow_probability/python/stats/kendalls_tau_test.py

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -26,17 +26,6 @@
2626
@test_util.test_all_tf_execution_regimes
2727
class KendallsTauTest(test_util.TestCase):
2828

29-
def test_iterative_mergesort(self):
30-
values = [7, 3, 9, 0, -6, 12, 54, 3, -6, 88, 1412]
31-
array = tf.constant(values, tf.int32)
32-
iperm = tf.range(len(values), dtype=tf.int32)
33-
exchanges, perm = tfp.stats.iterative_mergesort(array, iperm)
34-
expected = sorted(values)
35-
self.assertAllEqual(expected, tf.gather(array, perm))
36-
ordered, _ = tfp.stats.iterative_mergesort(array, perm)
37-
self.assertAllEqual(ordered, 0)
38-
self.assertAllEqual(exchanges, 19)
39-
4029
def test_kendall_tau(self):
4130
x1 = [12, 2, 1, 12, 2]
4231
x2 = [1, 4, 7, 1, 0]
@@ -46,14 +35,6 @@ def test_kendall_tau(self):
4635
tf.constant(x1, tf.float32), tf.constant(x2, tf.float32)))
4736
self.assertAllClose(expected, res, atol=1e-5)
4837

49-
def test_lexicographical_sort(self):
50-
primary = [12, 2, 1, 12, 2]
51-
secondary = [1, 4, 7, 1, 0]
52-
expected = [2, 4, 1, 0, 3] # Assumes stable sort.
53-
res = self.evaluate(
54-
tfp.stats.lexicographical_indirect_sort(primary, secondary))
55-
self.assertAllEqual(expected, res)
56-
5738
def test_kendall_tau_float(self):
5839
x1 = [0.12, 0.02, 0.01, 0.12, 0.02]
5940
x2 = [0.1, 0.4, 0.7, 0.1, 0.0]

0 commit comments

Comments
 (0)