19 | 19 |
20 | 20 |
21 | 21 | __all__ = [
| 22 | +    "estimate_tails",
22 | 23 |     "quantization_offset",
23 | 24 |     "lower_tail",
24 | 25 |     "upper_tail",
25 | 26 | ]
26 | 27 |
27 | 28 |
28 | | -def estimate_tail(func, target, shape, dtype):
29 | | -  """Estimates approximate tail quantiles."""
30 | | -  dtype = tf.as_dtype(dtype)
31 | | -  shape = tf.convert_to_tensor(shape, tf.int32)
32 | | -  target = tf.convert_to_tensor(target, dtype)
33 | | -  opt = tf.keras.optimizers.Adam(learning_rate=.1)
34 | | -  tails = tf.Variable(
35 | | -      tf.zeros(shape, dtype=dtype), trainable=False, name="tails")
36 | | -  loss = best_loss = tf.fill(shape, tf.constant(float("inf"), dtype=dtype))
37 | | -  while tf.reduce_any(loss == best_loss):
38 | | -    with tf.GradientTape(watch_accessed_variables=False) as tape:
39 | | -      tape.watch(tails)
40 | | -      loss = abs(func(tails) - target)
41 | | -    best_loss = tf.minimum(best_loss, loss)
42 | | -    gradient = tape.gradient(loss, tails)
43 | | -    opt.apply_gradients([(gradient, tails)])
44 | | -  return tails.value()
| 29 | +# TODO(jonycgn): Consider wrapping in tf.function.
| 30 | +def estimate_tails(func, target, shape, dtype):
| 31 | +  """Estimates approximate tail quantiles.
| 32 | +
| 33 | +  This runs a simple Adam iteration to determine tail quantiles. The
| 34 | +  objective is to find an `x` such that:
| 35 | +  ```
| 36 | +  func(x) == target
| 37 | +  ```
| 38 | +  For instance, if `func` is a CDF and the target is a quantile value, this
| 39 | +  would find the approximate location of that quantile. Note that `func` is
| 40 | +  assumed to be monotonic. When each tail estimate has passed the optimal value
| 41 | +  of `x`, the algorithm does 10 additional iterations and then stops.
| 42 | +
| 43 | +  This operation is vectorized. The tensor shape of `x` is given by `shape`, and
| 44 | +  `target` must have a shape that is broadcastable to the output of `func(x)`.
| 45 | +
| 46 | +  Arguments:
| 47 | +    func: A callable that computes a cumulative distribution function, survival
| 48 | +      function, or similar.
| 49 | +    target: The desired target value.
| 50 | +    shape: The shape of the `tf.Tensor` representing `x`.
| 51 | +    dtype: The `tf.dtypes.Dtype` of the computation (and the return value).
| 52 | +
| 53 | +  Returns:
| 54 | +    A `tf.Tensor` representing the solution (`x`).
| 55 | +  """
| 56 | +  with tf.name_scope("estimate_tails"):
| 57 | +    dtype = tf.as_dtype(dtype)
| 58 | +    shape = tf.convert_to_tensor(shape, tf.int32)
| 59 | +    target = tf.convert_to_tensor(target, dtype)
| 60 | +
| 61 | +    def loop_cond(tails, m, v, count):
| 62 | +      del tails, m, v  # unused
| 63 | +      return tf.reduce_min(count) < 10
| 64 | +
| 65 | +    def loop_body(tails, m, v, count):
| 66 | +      with tf.GradientTape(watch_accessed_variables=False) as tape:
| 67 | +        tape.watch(tails)
| 68 | +        loss = abs(func(tails) - target)
| 69 | +      grad = tape.gradient(loss, tails)
| 70 | +      m = .5 * m + .5 * grad  # Adam mean estimate.
| 71 | +      v = .9 * v + .1 * tf.square(grad)  # Adam variance estimate.
| 72 | +      tails -= .5 * m / (tf.sqrt(v) + 1e-7)
| 73 | +      # Start counting when the gradient flips sign (note that this assumes
| 74 | +      # `tails` is initialized to zero).
| 75 | +      count = tf.where(
| 76 | +          tf.math.logical_or(count > 0, tails * grad > 0),
| 77 | +          count + 1, count)
| 78 | +      return tails, m, v, count
| 79 | +
| 80 | +    init_tails = tf.zeros(shape, dtype=dtype)
| 81 | +    init_m = tf.zeros(shape, dtype=dtype)
| 82 | +    init_v = tf.ones(shape, dtype=dtype)
| 83 | +    init_count = tf.zeros(shape, dtype=tf.int32)
| 84 | +    return tf.while_loop(
| 85 | +        loop_cond, loop_body, (init_tails, init_m, init_v, init_count),
| 86 | +        back_prop=False)[0]
45 | 87 |
46 | 88 |
47 | 89 | def quantization_offset(distribution):
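As a quick illustration of what the new helper computes, here is a minimal sketch; it is not part of the commit, and assumes TensorFlow 2.x, TensorFlow Probability, and that `estimate_tails` is imported from the module changed above:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative sketch only: invert a standard normal CDF with the new helper.
# Assumes `estimate_tails` (defined in the diff above) is in scope.
dist = tfp.distributions.Normal(loc=0., scale=1.)
tails = estimate_tails(dist.cdf, target=.01, shape=[1], dtype=tf.float32)

# The Adam iteration stops 10 steps after each element overshoots its
# solution, so the result is approximate:
print(tails.numpy())               # roughly [-2.33]
print(dist.quantile(.01).numpy())  # exact: -2.3263...
```

Note the design change relative to the deleted `estimate_tail`: the new version needs no `tf.Variable` and no Keras optimizer, and runs its hand-rolled Adam update inside `tf.while_loop`, so it can execute entirely in graph mode.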
@@ -113,7 +155,7 @@ def lower_tail(distribution, tail_mass):
113 | 155 |       tail = distribution.quantile(tail_mass / 2)
114 | 156 |     except NotImplementedError:
115 | 157 |       try:
116 | | -        tail = estimate_tail(
| 158 | +        tail = estimate_tails(
117 | 159 |             distribution.log_cdf, tf.math.log(tail_mass / 2),
118 | 160 |             distribution.batch_shape_tensor(), distribution.dtype)
119 | 161 |       except NotImplementedError:
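The hunk above is the fallback path taken when a distribution implements `log_cdf` but not `quantile`: it solves `log_cdf(x) == log(tail_mass / 2)`. A hypothetical standalone equivalent, with illustrative values:

```python
# Hypothetical standalone equivalent of the `lower_tail` fallback.
dist = tfp.distributions.Normal(loc=0., scale=1.)  # stand-in distribution
tail_mass = 2. ** -8                               # illustrative value
lower = estimate_tails(
    dist.log_cdf, tf.math.log(tail_mass / 2),
    dist.batch_shape_tensor(), dist.dtype)
```

(Normal actually implements `quantile`, so the real code would never reach this branch for it; it stands in here only to keep the sketch runnable.)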
@@ -149,7 +191,7 @@ def upper_tail(distribution, tail_mass):
149 | 191 |       tail = distribution.quantile(1 - tail_mass / 2)
150 | 192 |     except NotImplementedError:
151 | 193 |       try:
152 | | -        tail = estimate_tail(
| 194 | +        tail = estimate_tails(
153 | 195 |             distribution.log_survival_function, tf.math.log(tail_mass / 2),
154 | 196 |             distribution.batch_shape_tensor(), distribution.dtype)
155 | 197 |       except NotImplementedError:
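`upper_tail` mirrors this with the survival function, solving `log_survival_function(x) == log(tail_mass / 2)`. The same kind of hypothetical sketch, continuing the snippet above:

```python
# Hypothetical mirror of the `upper_tail` fallback; for a symmetric,
# zero-centered distribution, `upper` should come out close to -lower.
upper = estimate_tails(
    dist.log_survival_function, tf.math.log(tail_mass / 2),
    dist.batch_shape_tensor(), dist.dtype)
```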