Skip to content

Commit 943de2c

Browse files
langmoretensorflower-gardener
authored andcommitted
Rectilinear grid support for TFP interpolation with batch_interp_rectilinear_nd_grid. A "rectilinear" grid is one where grid cells are rectangles. The rectangles do not have to be equal size. E.g. one dimension can be logarithmically spaced.
PiperOrigin-RevId: 462824614
1 parent dd3e117 commit 943de2c

File tree

3 files changed

+398
-47
lines changed

3 files changed

+398
-47
lines changed

tensorflow_probability/python/math/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tensorflow_probability.python.math.gradient import value_and_gradient
4343
from tensorflow_probability.python.math.gram_schmidt import gram_schmidt
4444
from tensorflow_probability.python.math.integration import trapz
45+
from tensorflow_probability.python.math.interpolation import batch_interp_rectilinear_nd_grid
4546
from tensorflow_probability.python.math.interpolation import batch_interp_regular_1d_grid
4647
from tensorflow_probability.python.math.interpolation import batch_interp_regular_nd_grid
4748
from tensorflow_probability.python.math.interpolation import interp_regular_1d_grid
@@ -95,6 +96,7 @@
9596
_allowed_symbols = [
9697
'atan_difference',
9798
'betainc',
99+
'batch_interp_rectilinear_nd_grid',
98100
'batch_interp_regular_1d_grid',
99101
'batch_interp_regular_nd_grid',
100102
'bessel_iv_ratio',

tensorflow_probability/python/math/interpolation.py

Lines changed: 270 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
'interp_regular_1d_grid',
3030
'batch_interp_regular_1d_grid',
3131
'batch_interp_regular_nd_grid',
32+
'batch_interp_rectilinear_nd_grid',
3233
]
3334

3435

@@ -500,6 +501,9 @@ def batch_interp_regular_nd_grid(x,
500501
The interpolant is built from reference values indexed by `nd` dimensions
501502
of `y_ref`, starting at `axis`.
502503
504+
The x grid span is defined by `x_ref_min`, `x_ref_max`. The number of grid
505+
points is inferred from the shape of `y_ref`.
506+
503507
For example, take the case of a `2-D` scalar valued function and no leading
504508
batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
505509
is the reference value corresponding to grid point
@@ -539,10 +543,12 @@ def batch_interp_regular_nd_grid(x,
539543
y_interp: Interpolation between members of `y_ref`, at points `x`.
540544
`Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`
541545
546+
Exceptions will be raised if shapes are statically determined to be wrong.
547+
542548
Raises:
543-
ValueError: If `rank(x) < 2` is determined statically.
544-
ValueError: If `axis` is not a scalar is determined statically.
545-
ValueError: If `axis + nd > rank(y_ref)` is determined statically.
549+
ValueError: If `rank(x) < 2`.
550+
ValueError: If `axis` is not a scalar.
551+
ValueError: If `axis + nd > rank(y_ref)`.
546552
547553
#### Examples
548554
@@ -575,7 +581,7 @@ def func(x0, x1):
575581
576582
y_ref = func(x0s, x1s)
577583
578-
x = np.pi * tf.random.uniform(shape=(10, 2))
584+
x = 2 * np.pi * tf.random.uniform(shape=(10, 2))
579585
580586
tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
581587
==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
@@ -587,15 +593,7 @@ def func(x0, x1):
587593
dtype_hint=tf.float32)
588594

589595
# Arg checking.
590-
if isinstance(fill_value, str):
591-
if fill_value != 'constant_extension':
592-
raise ValueError(
593-
'A fill value ({}) was not an allowed string ({})'.format(
594-
fill_value, 'constant_extension'))
595-
else:
596-
fill_value = tf.convert_to_tensor(
597-
fill_value, name='fill_value', dtype=dtype)
598-
_assert_ndims_statically(fill_value, expect_ndims=0)
596+
fill_value = _intake_fill_value_for_nd_interp(fill_value, dtype)
599597

600598
# x.shape = [..., nd].
601599
x = tf.convert_to_tensor(x, name='x', dtype=dtype)
@@ -623,19 +621,7 @@ def func(x0, x1):
623621
x_ref_max.shape[-1:], x_ref_min.shape[-1:])
624622

625623
# Convert axis and check it statically.
626-
axis = ps.convert_to_shape_tensor(axis, dtype=tf.int32, name='axis')
627-
axis = ps.non_negative_axis(axis, ps.rank(y_ref))
628-
tensorshape_util.assert_has_rank(axis.shape, 0)
629-
axis_ = tf.get_static_value(axis)
630-
y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
631-
if axis_ is not None and y_ref_rank_ is not None:
632-
if axis_ + nd > y_ref_rank_:
633-
raise ValueError(
634-
'Since dims `[axis, axis + nd)` index the interpolation table, we '
635-
'must have `axis + nd <= rank(y_ref)`. Found: '
636-
'`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
637-
'dimensions of `x_ref_min` to be {}.'.format(
638-
axis_, y_ref_rank_, nd))
624+
axis = _intake_axis_for_nd_interp(axis, y_ref, nd)
639625

640626
x_batch_shape = ps.shape_slice(x, np.s_[:-2])
641627
x_ref_min_batch_shape = ps.shape_slice(x_ref_min, np.s_[:-1])
@@ -665,7 +651,7 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
665651
_batch_shape_of_zeros_with_rightmost_singletons(
666652
n_singletons=ps.rank(y_ref) - axis))
667653

668-
# In this function,
654+
# At this point,
669655
# x.shape = [A1, ..., An, D, nd], where n = batch_ndims
670656
# and
671657
# y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
@@ -678,6 +664,9 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
678664
# ny[k] is number of y reference points in interp dim k.
679665
# It is used to indicate the dimension sizes.
680666
ny = tf.cast(
667+
# After broadcasting y_ref with x, slice(batch_ndims, batch_ndims + nd)
668+
# is the proper way to extract ny. Before broadcasting, use
669+
# slice(axis, axis + nd)
681670
ps.shape_slice(y_ref, np.s_[batch_ndims:batch_ndims + nd]), dtype)
682671

683672
# Map [x_ref_min, x_ref_max] to [0, ny - 1].
@@ -698,6 +687,228 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
698687
batch_ndims=batch_ndims)
699688

700689

690+
def batch_interp_rectilinear_nd_grid(x,
691+
x_grid_points,
692+
y_ref,
693+
axis,
694+
fill_value='constant_extension',
695+
name=None):
696+
"""Multi-linear interpolation on a rectilinear grid.
697+
698+
Given [a batch of] reference values, this function computes a multi-linear
699+
interpolant and evaluates it on [a batch of] new `x` values. This is a
700+
multi-dimensional generalization of [Bilinear Interpolation](
701+
https://en.wikipedia.org/wiki/Bilinear_interpolation).
702+
703+
The interpolant is built from reference values indexed by `nd` dimensions
704+
of `y_ref`, starting at `axis`.
705+
706+
The x grid is defined by `1-D` points along each dimension. These points must
707+
be sorted, but may have unequal spacing.
708+
709+
For example, take the case of a `2-D` scalar valued function and no leading
710+
batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
711+
is the reference value corresponding to grid point
712+
713+
```[x_grid_points[0][i], x_grid_points[1][j]]```
714+
715+
In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
716+
with leading dimensions in `x`, and `x_grid_points[k]`, `k = 0, ..., nd - 1`.
717+
718+
Args:
719+
x: Numeric `Tensor` The x-coordinates of the interpolated output values for
720+
each batch. Shape `[..., D, nd]`, designating [a batch of] `D`
721+
coordinates in `nd` space. `D` must be `>= 1` and is not a batch dim.
722+
x_grid_points: Tuple of dimension points. `x_grid_points[k]` are a shape
723+
`[..., Ck]` `Tensor` of the same dtype as `x` that must be sorted along
724+
the innermost (-1) axis. These represent [a batch of] points defining the
725+
`kth` dimension values.
726+
y_ref: `Tensor` of same `dtype` as `x`. The reference output values. Shape
727+
`[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
728+
values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
729+
function (for `M >= 0`).
730+
axis: Scalar integer `Tensor`. Dimensions `[axis, axis + nd)` of `y_ref`
731+
index the interpolation table. E.g. `3-D` interpolation of a scalar
732+
valued function requires `axis=-3` and a `3-D` matrix valued function
733+
requires `axis=-5`.
734+
fill_value: Determines what values output should take for `x` values that
735+
are below/above the min/max values in `x_grid_points`.
736+
'constant_extension' ==> Extend as constant function.
737+
Default value: `'constant_extension'`
738+
name: A name to prepend to created ops.
739+
Default value: `'batch_interp_rectilinear_nd_grid'`.
740+
741+
Returns:
742+
y_interp: Interpolation between members of `y_ref`, at points `x`.
743+
`Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`
744+
745+
Exceptions will be raised if shapes are statically determined to be wrong.
746+
747+
Raises:
748+
ValueError: If `rank(x) < 2`
749+
ValueError: If `axis` is not a scalar.
750+
ValueError: If `axis + nd > rank(y_ref)`.
751+
ValueError: If `x_grid_points[k].shape[-1] != y_ref.shape[axis + k]`.
752+
753+
#### Examples
754+
755+
Interpolate a function of one variable.
756+
757+
```python
758+
x_grid = tf.linspace(0., 1., 20)**2 # Nonlinearly spaced
759+
y_ref = tf.exp(x_grid)
760+
761+
tfp.math.batch_interp_rectilinear_nd_grid(
762+
# x.shape = [3, 1], with the trailing `1` for `1-D`.
763+
x=[[6.0], [0.5], [3.3]], x_grid_points=(x_grid,), y_ref=y_ref, axis=0)
764+
==> approx [exp(6.0), exp(0.5), exp(3.3)]
765+
```
766+
767+
Interpolate a scalar function of two variables.
768+
769+
```python
770+
x0_grid = tf.linspace(0., 2 * np.pi, num=100),
771+
x1_grid = tf.linspace(0., 2 * np.pi, num=100),
772+
773+
# Build y_ref.
774+
x0s, x1s = tf.meshgrid(x0_grid, x1_grid, indexing='ij')
775+
776+
def func(x0, x1):
777+
return tf.sin(x0) * tf.cos(x1)
778+
779+
y_ref = func(x0s, x1s)
780+
781+
x = np.pi * tf.random.uniform(shape=(10, 2))
782+
783+
tfp.math.batch_interp_regular_nd_grid(x, x_grid_points=(x0_grid, x1_grid),
784+
y_ref, axis=-2)
785+
==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
786+
```
787+
788+
"""
789+
with tf.name_scope(name or 'batch_interp_rectilinear_nd_grid'):
790+
if not isinstance(x_grid_points, tuple):
791+
raise ValueError(
792+
f'`x_grid_points` must be a tuple. Found {type(x_grid_points)}')
793+
794+
dtype = dtype_util.common_dtype([x, y_ref] + list(x_grid_points),
795+
dtype_hint=tf.float32)
796+
797+
# Arg checking.
798+
fill_value = _intake_fill_value_for_nd_interp(fill_value, dtype)
799+
800+
# x.shape = [..., nd].
801+
x = tf.convert_to_tensor(x, name='x', dtype=dtype)
802+
_assert_ndims_statically(x, expect_ndims_at_least=2)
803+
804+
# y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
805+
y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)
806+
807+
# x_ref_min.shape = [nd]
808+
x_grid_points = tuple(
809+
tf.convert_to_tensor(p, dtype=dtype) for p in x_grid_points)
810+
for p in x_grid_points:
811+
_assert_ndims_statically(p, expect_ndims_at_least=1, expect_static=True)
812+
813+
# nd is the number of dimensions indexing the interpolation table, it's the
814+
# 'nd' in the function name.
815+
nd = len(x_grid_points)
816+
817+
# Convert axis and check it statically.
818+
axis = _intake_axis_for_nd_interp(axis, y_ref, nd)
819+
820+
# Check that the number of grid points implied by x_grid_points and y_ref
821+
# match.
822+
for k, p_k in enumerate(x_grid_points):
823+
nx_k = p_k.shape[-1]
824+
ny_k = y_ref.shape[axis + k]
825+
if ny_k is not None and ny_k is not None and nx_k != ny_k:
826+
raise ValueError(
827+
f'x_grid_points[{k}] contained {nx_k} points, which differed from '
828+
f'{ny_k}, the number of points in the {k}th table dimension of '
829+
f'y_ref.')
830+
831+
x_batch_shape = ps.shape_slice(x, np.s_[:-2])
832+
x_grid_points_batch_shapes = list(
833+
ps.shape_slice(p, np.s_[:-1]) for p in x_grid_points)
834+
y_ref_batch_shape = ps.shape_slice(y_ref, np.s_[:axis])
835+
836+
# Do a brute-force broadcast of batch dims (add zeros).
837+
batch_shape = y_ref_batch_shape
838+
for tensor in [x_batch_shape] + x_grid_points_batch_shapes:
839+
batch_shape = ps.broadcast_shape(batch_shape, tensor)
840+
841+
def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
842+
"""Return Tensor of zeros with some singletons on the rightmost dims."""
843+
ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
844+
return ps.concat([batch_shape, ones], axis=0)
845+
846+
x = _broadcast_with(
847+
x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
848+
x_grid_points = tuple(
849+
_broadcast_with(
850+
p, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
851+
for p in x_grid_points)
852+
y_ref = _broadcast_with(
853+
y_ref,
854+
_batch_shape_of_zeros_with_rightmost_singletons(
855+
n_singletons=ps.rank(y_ref) - axis))
856+
857+
# At this point,
858+
# x.shape = [A1, ..., An, D, nd], where n = batch_ndims
859+
# and
860+
# y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
861+
# y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with value
862+
# at index [i1,...,ind] in the interpolation table.
863+
# and `p_k = x_grid_points[k]` has shape [A1, ..., An, Ck].
864+
865+
batch_ndims = ps.rank(x) - 2
866+
867+
# ny[k] is number of y reference points in interp dim k.
868+
# It is used to indicate the dimension sizes...
869+
# It could also be called nx, if we actually materialized a grid of x
870+
# points. We don't though, as x points are given only as axis values.
871+
ny = tf.cast(
872+
ps.shape_slice(y_ref, np.s_[batch_ndims:batch_ndims + nd]), tf.int32)
873+
874+
# Map the `kth` point `x_grid_points[k]` to [0, ny[k] - 1].
875+
# This is the (fractional) index of x, "unclipped" meaning it may take
876+
# values outside [0, ..., ny[k]].
877+
# x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
878+
# interpolation table for the dth x value.
879+
x_idx_unclipped = []
880+
for k, p_k in enumerate(x_grid_points):
881+
# x_k and x_k_clipped shape [A1, ..., An, D].
882+
# Clip x_k below...no need to clip above since, in the place it is used
883+
# below, we have a tf.minimum(ny[k] - 1,...)
884+
x_k = x[..., k]
885+
x_k_clipped = tf.maximum(x_k, tf.reduce_min(p_k, axis=-1, keepdims=True))
886+
887+
# This construction of indices ensures that idx_below_k < idx_above_k.
888+
# In particular, the use of x_k_clipped ensures this, even if x_k is OOB.
889+
idx_above_k = tf.minimum(
890+
ny[k] - 1, tf.searchsorted(p_k, x_k_clipped, side='right'))
891+
idx_below_k = tf.maximum(idx_above_k - 1, 0)
892+
x_above_k = tf.gather(p_k, idx_above_k, batch_dims=batch_ndims)
893+
x_below_k = tf.gather(p_k, idx_below_k, batch_dims=batch_ndims)
894+
895+
# The use of x_k (not clipped) here allows x_idx_unclipped to be < 0 or >
896+
# ny[k] - 1.
897+
x_idx_unclipped.append(
898+
tf.cast(idx_below_k, dtype) + (x_k - x_below_k) /
899+
(x_above_k - x_below_k))
900+
901+
x_idx_unclipped = tf.stack(x_idx_unclipped, axis=-1)
902+
903+
return _batch_interp_with_gather_nd(
904+
x=x,
905+
x_idx_unclipped=x_idx_unclipped,
906+
y_ref=y_ref,
907+
nd=nd,
908+
fill_value=fill_value,
909+
batch_ndims=batch_ndims)
910+
911+
701912
def _batch_interp_with_gather_nd(x, x_idx_unclipped, y_ref, nd, fill_value,
702913
batch_ndims):
703914
"""Batch interpolation starting with indices."""
@@ -852,6 +1063,38 @@ def _assert_ndims_statically(x,
8521063
expect_ndims_at_least, ndims))
8531064

8541065

1066+
def _intake_fill_value_for_nd_interp(fill_value, dtype):
1067+
"""Check `fill_value` and return after converting numeric to tensor."""
1068+
if isinstance(fill_value, str):
1069+
if fill_value != 'constant_extension':
1070+
raise ValueError(
1071+
'A fill value ({}) was not an allowed string ({})'.format(
1072+
fill_value, 'constant_extension'))
1073+
else:
1074+
fill_value = tf.convert_to_tensor(
1075+
fill_value, name='fill_value', dtype=dtype)
1076+
_assert_ndims_statically(fill_value, expect_ndims=0)
1077+
return fill_value
1078+
1079+
1080+
def _intake_axis_for_nd_interp(axis, y_ref, nd):
1081+
"""Convert `axis` to its non-negative value and return after validation."""
1082+
axis = ps.convert_to_shape_tensor(axis, dtype=tf.int32, name='axis')
1083+
axis = ps.non_negative_axis(axis, ps.rank(y_ref))
1084+
tensorshape_util.assert_has_rank(axis.shape, 0)
1085+
axis_ = tf.get_static_value(axis)
1086+
y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
1087+
if axis_ is not None and y_ref_rank_ is not None:
1088+
if axis_ + nd > y_ref_rank_:
1089+
raise ValueError(
1090+
'Since dims `[axis, axis + nd)` index the interpolation table, we '
1091+
'must have `axis + nd <= rank(y_ref)`. Found: '
1092+
'`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
1093+
'dimensions of `x_ref_min` to be {}.'.format(
1094+
axis_, y_ref_rank_, nd))
1095+
return axis
1096+
1097+
8551098
def _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis):
8561099
"""Make func to expand left/right (of axis) dims of tensors shaped like x."""
8571100
# This expansion is to help x broadcast with `y`, the output.

0 commit comments

Comments
 (0)