hypergeometric 2F1 undefined for legitimate inputs #2001

@maciejskorski

In my research on Gaussian processes I needed the hypergeometric function 2F1 for small arguments (less than 1 in absolute value).

I have found that the current implementation of hyp2f1_small_argument is incomplete: it returns nan for some legitimate inputs with |z| < 1.

To reproduce, run this in Colab and compare the result against SciPy or Mathematica:

import tensorflow as tf
import tensorflow_probability as tfp

hyp2f1 = tfp.math.hypergeometric.hyp2f1_small_argument

H = 1
a = tf.constant(1.0)
b = tf.constant(0.5 - H)
c = tf.constant(H + 1.5)
x = tf.constant(0.9901961)
hyp2f1(a, b, c, x)  # nan, should be ~0.753603

Tested under tfp==0.25.0 and tf==2.18.0 on Google Colab.
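
For readers without SciPy or Mathematica at hand, the expected value can be cross-checked by summing the Gauss series directly, which converges for |z| < 1. This is a minimal sketch in plain Python, not TFP code: `hyp2f1_series` is a hypothetical helper name, and the stopping rule (relative term size) is a pragmatic assumption rather than a rigorous error bound. The value at z = 1 from Gauss's summation theorem is included as a sanity check, since z = 0.9901961 is close to 1.

```python
import math


def hyp2f1_series(a, b, c, z, max_terms=100_000, tol=1e-17):
    """Sum the Gauss series term by term; only meaningful for |z| < 1."""
    term = 1.0   # n = 0 term of the series
    total = 1.0
    for n in range(max_terms):
        # Ratio of consecutive terms: (a + n)(b + n) / ((c + n)(n + 1)) * z
        term *= (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        total += term
        if abs(term) < tol * abs(total):
            break
    return total


H = 1
a, b, c, z = 1.0, 0.5 - H, H + 1.5, 0.9901961

reference = hyp2f1_series(a, b, c, z)

# Gauss's theorem: 2F1(a, b; c; 1) = G(c) G(c - a - b) / (G(c - a) G(c - b)),
# valid here because c > a + b.
at_one = (math.gamma(c) * math.gamma(c - a - b)
          / (math.gamma(c - a) * math.gamma(c - b)))

print(reference, at_one)
```

The series value should agree with the ~0.753603 quoted above, and the limit at z = 1 is 0.75, so a finite result close to 0.75 is what the TFP function ought to return for this input.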

Here is the problematic code:

@tf.custom_gradient
def hyp2f1_small_argument(a, b, c, z, name=None):
  """Compute the Hypergeometric function 2f1(a, b, c, z) when |z| <= 1.

  Given `a, b, c` and `z`, compute Gauss' Hypergeometric Function, specified
  by the series:

  `1 + (a * b/c) * z + (a * (a + 1) * b * (b + 1) / ((c * (c + 1)) * z**2 / 2 +
  ... (a)_n * (b)_n / (c)_n * z ** n / n! + ....`

  NOTE: Gradients with only respect to `z` are available.
  NOTE: It is recommended that the arguments are `float64` due to the heavy
  loss of precision in float32.

  Args:
    a: Floating-point `Tensor`, broadcastable with `b, c, z`. Parameter for the
      numerator of the series fraction.
    b: Floating-point `Tensor`, broadcastable with `a, c, z`. Parameter for the
      numerator of the series fraction.
    c: Floating-point `Tensor`, broadcastable with `a, b, z`. Parameter for the
      denominator of the series fraction.
    z: Floating-point `Tensor`, broadcastable `a, b, c`. Value to compute
      `2F1(a, b, c, z)` at. Only values of `|z| < 1` are allowed.
    name: A name for the operation (optional).
      Default value: `None` (i.e., 'continued_fraction').

  Returns:
    hypergeo: `2F1(a, b, c, z)`

  #### References

  [1] F. Johansson. Computing hypergeometric functions rigorously.
     ACM Transactions on Mathematical Software, August 2019.
     https://arxiv.org/abs/1606.06977
  [2] J. Pearson, S. Olver, M. Porter. Numerical methods for the computation of
     the confluent and Gauss hypergeometric functions.
     Numerical Algorithms, August 2016.
  [3] M. Abramowitz, I. Stegun. Handbook of Mathematical Functions with
     Formulas, Graphs and Mathematical Tables.
  """
  with tf.name_scope(name or 'hyp2f1_small_argument'):
    dtype = dtype_util.common_dtype([a, b, c, z], tf.float32)
    numpy_dtype = dtype_util.as_numpy_dtype(dtype)
    a = tf.convert_to_tensor(a, dtype=dtype)
    b = tf.convert_to_tensor(b, dtype=dtype)
    c = tf.convert_to_tensor(c, dtype=dtype)
    z = tf.convert_to_tensor(z, dtype=dtype)

    # Mask out exceptional cases to ensure that the series transformations
    # terminate fast.
    safe_a, safe_b, safe_c, safe_z = _mask_exceptional_arguments(
        a, b, c, z, numpy_dtype)

    # TODO(b/128632717): Extend this by including transformations for:
    # * Large parameter ranges. Specifically use Hypergeometric recurrences
    # to decrease the parameter values. This should be done via backward
    # recurrences rather than forward recurrences since those are numerically
    # stable.
    # * Include |z| > 1. This can be done via Hypergeometric identities that
    # transform to |z| < 1.
    # * Handling exceptional cases where parameters are negative integers.

    # Assume that |b| > |a|. Swapping the two makes no effect on the
    # calculation.
    a_small = tf.where(
        tf.math.abs(safe_a) > tf.math.abs(safe_b), safe_b, safe_a)
    safe_b = tf.where(tf.math.abs(safe_a) > tf.math.abs(safe_b), safe_a, safe_b)
    safe_a = a_small

    d = safe_c - safe_a - safe_b

    # Use the identity
    # 2F1(a , b, c, z) = (1 - z) ** d * 2F1(c - a, c - b, c, z).
    # when the numerator coefficients become smaller.
    should_use_linear_transform = (
        (tf.math.abs(c - a) < tf.math.abs(a)) &
        (tf.math.abs(c - b) < tf.math.abs(b)))
    safe_a = tf.where(should_use_linear_transform, c - a, a)
    safe_b = tf.where(should_use_linear_transform, c - b, b)

    # When -0.5 < z < 0.9, use approximations to Taylor Series.
    safe_z_small = tf.where(
        (safe_z >= 0.9) | (safe_z <= -0.5), numpy_dtype(0.), safe_z)
    taylor_series = _hyp2f1_internal(safe_a, safe_b, safe_c, safe_z_small)

    # When z >= 0.9 or -0.5 > z, we use hypergeometric identities to ensure
    # that |z| is small.
    safe_positive_z_large = tf.where(safe_z >= 0.9, safe_z, numpy_dtype(1.))
    hyp2f1_z_near_one = _hyp2f1_z_near_one(
        safe_a, safe_b, safe_c, safe_positive_z_large)
    safe_negative_z_large = tf.where(safe_z <= -0.5, safe_z, numpy_dtype(-1.))
    hyp2f1_z_near_negative_one = _hyp2f1_z_near_negative_one(
        safe_a, safe_b, safe_c, safe_negative_z_large)

    result = tf.where(
        safe_z >= 0.9, hyp2f1_z_near_one,
        tf.where(safe_z <= -0.5, hyp2f1_z_near_negative_one, taylor_series))

    # Now if we applied the linear transformation identity, we need to
    # add a term (1 - z) ** (c - a - b)
    result = tf.where(
        should_use_linear_transform,
        tf.math.exp(d * tf.math.log1p(-safe_z)) * result,
        result)

    # Finally handle the exceptional cases.

    # First when z == 1., this expression diverges if c <= a + b, and otherwise
    # converges.
    hyp2f1_at_one = tf.math.exp(
        tf.math.lgamma(c) + tf.math.lgamma(c - a - b) -
        tf.math.lgamma(c - a) - tf.math.lgamma(c - b))
    sign_hyp2f1_at_one = (
        _gamma_negative(c) ^ _gamma_negative(c - a - b) ^
        _gamma_negative(c - a) ^ _gamma_negative(c - b))
    sign_hyp2f1_at_one = -2. * tf.cast(sign_hyp2f1_at_one, dtype) + 1.
    hyp2f1_at_one = hyp2f1_at_one * sign_hyp2f1_at_one

    result = tf.where(
        tf.math.equal(z, 1.),
        tf.where(c > a + b,
                 hyp2f1_at_one, numpy_dtype(np.nan)),
        result)

    # When a == c or b == c this reduces to (1 - z)**-b (-a respectively).
    result = tf.where(
        tf.math.equal(a, c),
        tf.math.exp(-b * tf.math.log1p(-z)),
        tf.where(
            tf.math.equal(b, c),
            tf.math.exp(-a * tf.math.log1p(-z)), result))

    # When c is a negative integer we can get a divergent series.
    result = tf.where(
        (_is_negative_integer(c) &
         ((a < c) | ~_is_negative_integer(a)) &
         ((b < c) | ~_is_negative_integer(b))),
        numpy_dtype(np.inf),
        result)

    def grad(dy):
      grad_z = a * b * dy * hyp2f1_small_argument(
          a + 1., b + 1., c + 1., z) / c
      # We don't have an easily computable gradient with respect to parameters,
      # so ignore that for now.
      broadcast_shape = functools.reduce(
          ps.broadcast_shape,
          [ps.shape(x) for x in [a, b, c]])
      _, grad_z = tfp_math.fix_gradient_for_broadcasting(
          [tf.ones(broadcast_shape, dtype=z.dtype), z],
          [tf.ones_like(grad_z), grad_z])
      return None, None, None, grad_z

    return result, grad
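
The branch conditions in the code above can be traced by hand for the failing input. The following plain-Python sketch mirrors only the conditions visible in the quoted function. `select_branch` is a hypothetical helper, not part of TFP, and it skips `_mask_exceptional_arguments` and the negative-integer check on `c`, whose sources are not shown here (c = 2.5 is positive for this input).

```python
def select_branch(a, b, c, z):
    """Mirror the visible tf.where conditions for scalar inputs.

    Returns (branch name, whether the linear transform applies, c - a - b).
    Assumption: masking and the negative-integer check on c are ignored.
    """
    # Main branch, following the nested tf.where on z.
    if z >= 0.9:
        branch = '_hyp2f1_z_near_one'
    elif z <= -0.5:
        branch = '_hyp2f1_z_near_negative_one'
    else:
        branch = '_hyp2f1_internal'

    # Later tf.where calls override the main branch for exceptional inputs.
    if z == 1.0:
        branch = 'closed form at z == 1'
    if a == c or b == c:
        branch = 'power of (1 - z)'

    # Condition for the identity 2F1(a, b, c, z) = (1 - z)**d * 2F1(c-a, c-b, c, z).
    uses_linear_transform = abs(c - a) < abs(a) and abs(c - b) < abs(b)
    return branch, uses_linear_transform, c - a - b


H = 1
trace = select_branch(1.0, 0.5 - H, H + 1.5, 0.9901961)
print(trace)
```

Under those assumptions the input is routed through `_hyp2f1_z_near_one` without the linear transformation, and c - a - b is exactly 2. The standard connection formula between z and 1 - z contains the factors Gamma(c - a - b) and Gamma(a + b - c), and Gamma has poles at non-positive integers, so an integer value of c - a - b is a plausible source of the nan. This is only a hypothesis and has not been checked against the source of `_hyp2f1_z_near_one`.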
