29
29
'interp_regular_1d_grid' ,
30
30
'batch_interp_regular_1d_grid' ,
31
31
'batch_interp_regular_nd_grid' ,
32
+ 'batch_interp_rectilinear_nd_grid' ,
32
33
]
33
34
34
35
@@ -500,6 +501,9 @@ def batch_interp_regular_nd_grid(x,
500
501
The interpolant is built from reference values indexed by `nd` dimensions
501
502
of `y_ref`, starting at `axis`.
502
503
504
+ The x grid span is defined by `x_ref_min`, `x_ref_max`. The number of grid
505
+ points is inferred from the shape of `y_ref`.
506
+
503
507
For example, take the case of a `2-D` scalar valued function and no leading
504
508
batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
505
509
is the reference value corresponding to grid point
@@ -539,10 +543,12 @@ def batch_interp_regular_nd_grid(x,
539
543
y_interp: Interpolation between members of `y_ref`, at points `x`.
540
544
`Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`
541
545
546
+ Exceptions will be raised if shapes are statically determined to be wrong.
547
+
542
548
Raises:
543
- ValueError: If `rank(x) < 2` is determined statically .
544
- ValueError: If `axis` is not a scalar is determined statically .
545
- ValueError: If `axis + nd > rank(y_ref)` is determined statically .
549
+ ValueError: If `rank(x) < 2`.
550
+ ValueError: If `axis` is not a scalar.
551
+ ValueError: If `axis + nd > rank(y_ref)`.
546
552
547
553
#### Examples
548
554
@@ -575,7 +581,7 @@ def func(x0, x1):
575
581
576
582
y_ref = func(x0s, x1s)
577
583
578
- x = np.pi * tf.random.uniform(shape=(10, 2))
584
+ x = 2 * np.pi * tf.random.uniform(shape=(10, 2))
579
585
580
586
tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
581
587
==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
@@ -587,15 +593,7 @@ def func(x0, x1):
587
593
dtype_hint = tf .float32 )
588
594
589
595
# Arg checking.
590
- if isinstance (fill_value , str ):
591
- if fill_value != 'constant_extension' :
592
- raise ValueError (
593
- 'A fill value ({}) was not an allowed string ({})' .format (
594
- fill_value , 'constant_extension' ))
595
- else :
596
- fill_value = tf .convert_to_tensor (
597
- fill_value , name = 'fill_value' , dtype = dtype )
598
- _assert_ndims_statically (fill_value , expect_ndims = 0 )
596
+ fill_value = _intake_fill_value_for_nd_interp (fill_value , dtype )
599
597
600
598
# x.shape = [..., nd].
601
599
x = tf .convert_to_tensor (x , name = 'x' , dtype = dtype )
@@ -623,19 +621,7 @@ def func(x0, x1):
623
621
x_ref_max .shape [- 1 :], x_ref_min .shape [- 1 :])
624
622
625
623
# Convert axis and check it statically.
626
- axis = ps .convert_to_shape_tensor (axis , dtype = tf .int32 , name = 'axis' )
627
- axis = ps .non_negative_axis (axis , ps .rank (y_ref ))
628
- tensorshape_util .assert_has_rank (axis .shape , 0 )
629
- axis_ = tf .get_static_value (axis )
630
- y_ref_rank_ = tf .get_static_value (tf .rank (y_ref ))
631
- if axis_ is not None and y_ref_rank_ is not None :
632
- if axis_ + nd > y_ref_rank_ :
633
- raise ValueError (
634
- 'Since dims `[axis, axis + nd)` index the interpolation table, we '
635
- 'must have `axis + nd <= rank(y_ref)`. Found: '
636
- '`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
637
- 'dimensions of `x_ref_min` to be {}.' .format (
638
- axis_ , y_ref_rank_ , nd ))
624
+ axis = _intake_axis_for_nd_interp (axis , y_ref , nd )
639
625
640
626
x_batch_shape = ps .shape_slice (x , np .s_ [:- 2 ])
641
627
x_ref_min_batch_shape = ps .shape_slice (x_ref_min , np .s_ [:- 1 ])
@@ -665,7 +651,7 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
665
651
_batch_shape_of_zeros_with_rightmost_singletons (
666
652
n_singletons = ps .rank (y_ref ) - axis ))
667
653
668
- # In this function ,
654
+ # At this point ,
669
655
# x.shape = [A1, ..., An, D, nd], where n = batch_ndims
670
656
# and
671
657
# y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
@@ -678,6 +664,9 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
678
664
# ny[k] is number of y reference points in interp dim k.
679
665
# It is used to indicate the dimension sizes.
680
666
ny = tf .cast (
667
+ # After broadcasting y_ref with x, slice(batch_ndims, batch_ndims + nd)
668
+ # is the proper way to extract ny. Before broadcasting, use
669
+ # slice(axis, axis + nd)
681
670
ps .shape_slice (y_ref , np .s_ [batch_ndims :batch_ndims + nd ]), dtype )
682
671
683
672
# Map [x_ref_min, x_ref_max] to [0, ny - 1].
@@ -698,6 +687,228 @@ def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
698
687
batch_ndims = batch_ndims )
699
688
700
689
690
+ def batch_interp_rectilinear_nd_grid (x ,
691
+ x_grid_points ,
692
+ y_ref ,
693
+ axis ,
694
+ fill_value = 'constant_extension' ,
695
+ name = None ):
696
+ """Multi-linear interpolation on a rectilinear grid.
697
+
698
+ Given [a batch of] reference values, this function computes a multi-linear
699
+ interpolant and evaluates it on [a batch of] new `x` values. This is a
700
+ multi-dimensional generalization of [Bilinear Interpolation](
701
+ https://en.wikipedia.org/wiki/Bilinear_interpolation).
702
+
703
+ The interpolant is built from reference values indexed by `nd` dimensions
704
+ of `y_ref`, starting at `axis`.
705
+
706
+ The x grid is defined by `1-D` points along each dimension. These points must
707
+ be sorted, but may have unequal spacing.
708
+
709
+ For example, take the case of a `2-D` scalar valued function and no leading
710
+ batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
711
+ is the reference value corresponding to grid point
712
+
713
+ ```[x_grid_points[0][i], x_grid_points[1][j]]```
714
+
715
+ In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
716
+ with leading dimensions in `x`, and `x_grid_points[k]`, `k = 0, ..., nd - 1`.
717
+
718
+ Args:
719
+ x: Numeric `Tensor` The x-coordinates of the interpolated output values for
720
+ each batch. Shape `[..., D, nd]`, designating [a batch of] `D`
721
+ coordinates in `nd` space. `D` must be `>= 1` and is not a batch dim.
722
+ x_grid_points: Tuple of dimension points. `x_grid_points[k]` are a shape
723
+ `[..., Ck]` `Tensor` of the same dtype as `x` that must be sorted along
724
+ the innermost (-1) axis. These represent [a batch of] points defining the
725
+ `kth` dimension values.
726
+ y_ref: `Tensor` of same `dtype` as `x`. The reference output values. Shape
727
+ `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
728
+ values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
729
+ function (for `M >= 0`).
730
+ axis: Scalar integer `Tensor`. Dimensions `[axis, axis + nd)` of `y_ref`
731
+ index the interpolation table. E.g. `3-D` interpolation of a scalar
732
+ valued function requires `axis=-3` and a `3-D` matrix valued function
733
+ requires `axis=-5`.
734
+ fill_value: Determines what values output should take for `x` values that
735
+ are below/above the min/max values in `x_grid_points`.
736
+ 'constant_extension' ==> Extend as constant function.
737
+ Default value: `'constant_extension'`
738
+ name: A name to prepend to created ops.
739
+ Default value: `'batch_interp_rectilinear_nd_grid'`.
740
+
741
+ Returns:
742
+ y_interp: Interpolation between members of `y_ref`, at points `x`.
743
+ `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`
744
+
745
+ Exceptions will be raised if shapes are statically determined to be wrong.
746
+
747
+ Raises:
748
+ ValueError: If `rank(x) < 2`
749
+ ValueError: If `axis` is not a scalar.
750
+ ValueError: If `axis + nd > rank(y_ref)`.
751
+ ValueError: If `x_grid_points[k].shape[-1] != y_ref.shape[axis + k]`.
752
+
753
+ #### Examples
754
+
755
+ Interpolate a function of one variable.
756
+
757
+ ```python
758
+ x_grid = tf.linspace(0., 1., 20)**2 # Nonlinearly spaced
759
+ y_ref = tf.exp(x_grid)
760
+
761
+ tfp.math.batch_interp_rectilinear_nd_grid(
762
+ # x.shape = [3, 1], with the trailing `1` for `1-D`.
763
+ x=[[6.0], [0.5], [3.3]], x_grid_points=(x_grid,), y_ref=y_ref, axis=0)
764
+ ==> approx [exp(6.0), exp(0.5), exp(3.3)]
765
+ ```
766
+
767
+ Interpolate a scalar function of two variables.
768
+
769
+ ```python
770
+ x0_grid = tf.linspace(0., 2 * np.pi, num=100),
771
+ x1_grid = tf.linspace(0., 2 * np.pi, num=100),
772
+
773
+ # Build y_ref.
774
+ x0s, x1s = tf.meshgrid(x0_grid, x1_grid, indexing='ij')
775
+
776
+ def func(x0, x1):
777
+ return tf.sin(x0) * tf.cos(x1)
778
+
779
+ y_ref = func(x0s, x1s)
780
+
781
+ x = np.pi * tf.random.uniform(shape=(10, 2))
782
+
783
+ tfp.math.batch_interp_regular_nd_grid(x, x_grid_points=(x0_grid, x1_grid),
784
+ y_ref, axis=-2)
785
+ ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
786
+ ```
787
+
788
+ """
789
+ with tf .name_scope (name or 'batch_interp_rectilinear_nd_grid' ):
790
+ if not isinstance (x_grid_points , tuple ):
791
+ raise ValueError (
792
+ f'`x_grid_points` must be a tuple. Found { type (x_grid_points )} ' )
793
+
794
+ dtype = dtype_util .common_dtype ([x , y_ref ] + list (x_grid_points ),
795
+ dtype_hint = tf .float32 )
796
+
797
+ # Arg checking.
798
+ fill_value = _intake_fill_value_for_nd_interp (fill_value , dtype )
799
+
800
+ # x.shape = [..., nd].
801
+ x = tf .convert_to_tensor (x , name = 'x' , dtype = dtype )
802
+ _assert_ndims_statically (x , expect_ndims_at_least = 2 )
803
+
804
+ # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
805
+ y_ref = tf .convert_to_tensor (y_ref , name = 'y_ref' , dtype = dtype )
806
+
807
+ # x_ref_min.shape = [nd]
808
+ x_grid_points = tuple (
809
+ tf .convert_to_tensor (p , dtype = dtype ) for p in x_grid_points )
810
+ for p in x_grid_points :
811
+ _assert_ndims_statically (p , expect_ndims_at_least = 1 , expect_static = True )
812
+
813
+ # nd is the number of dimensions indexing the interpolation table, it's the
814
+ # 'nd' in the function name.
815
+ nd = len (x_grid_points )
816
+
817
+ # Convert axis and check it statically.
818
+ axis = _intake_axis_for_nd_interp (axis , y_ref , nd )
819
+
820
+ # Check that the number of grid points implied by x_grid_points and y_ref
821
+ # match.
822
+ for k , p_k in enumerate (x_grid_points ):
823
+ nx_k = p_k .shape [- 1 ]
824
+ ny_k = y_ref .shape [axis + k ]
825
+ if ny_k is not None and ny_k is not None and nx_k != ny_k :
826
+ raise ValueError (
827
+ f'x_grid_points[{ k } ] contained { nx_k } points, which differed from '
828
+ f'{ ny_k } , the number of points in the { k } th table dimension of '
829
+ f'y_ref.' )
830
+
831
+ x_batch_shape = ps .shape_slice (x , np .s_ [:- 2 ])
832
+ x_grid_points_batch_shapes = list (
833
+ ps .shape_slice (p , np .s_ [:- 1 ]) for p in x_grid_points )
834
+ y_ref_batch_shape = ps .shape_slice (y_ref , np .s_ [:axis ])
835
+
836
+ # Do a brute-force broadcast of batch dims (add zeros).
837
+ batch_shape = y_ref_batch_shape
838
+ for tensor in [x_batch_shape ] + x_grid_points_batch_shapes :
839
+ batch_shape = ps .broadcast_shape (batch_shape , tensor )
840
+
841
+ def _batch_shape_of_zeros_with_rightmost_singletons (n_singletons ):
842
+ """Return Tensor of zeros with some singletons on the rightmost dims."""
843
+ ones = ps .ones (shape = [n_singletons ], dtype = tf .int32 )
844
+ return ps .concat ([batch_shape , ones ], axis = 0 )
845
+
846
+ x = _broadcast_with (
847
+ x , _batch_shape_of_zeros_with_rightmost_singletons (n_singletons = 2 ))
848
+ x_grid_points = tuple (
849
+ _broadcast_with (
850
+ p , _batch_shape_of_zeros_with_rightmost_singletons (n_singletons = 1 ))
851
+ for p in x_grid_points )
852
+ y_ref = _broadcast_with (
853
+ y_ref ,
854
+ _batch_shape_of_zeros_with_rightmost_singletons (
855
+ n_singletons = ps .rank (y_ref ) - axis ))
856
+
857
+ # At this point,
858
+ # x.shape = [A1, ..., An, D, nd], where n = batch_ndims
859
+ # and
860
+ # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
861
+ # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with value
862
+ # at index [i1,...,ind] in the interpolation table.
863
+ # and `p_k = x_grid_points[k]` has shape [A1, ..., An, Ck].
864
+
865
+ batch_ndims = ps .rank (x ) - 2
866
+
867
+ # ny[k] is number of y reference points in interp dim k.
868
+ # It is used to indicate the dimension sizes...
869
+ # It could also be called nx, if we actually materialized a grid of x
870
+ # points. We don't though, as x points are given only as axis values.
871
+ ny = tf .cast (
872
+ ps .shape_slice (y_ref , np .s_ [batch_ndims :batch_ndims + nd ]), tf .int32 )
873
+
874
+ # Map the `kth` point `x_grid_points[k]` to [0, ny[k] - 1].
875
+ # This is the (fractional) index of x, "unclipped" meaning it may take
876
+ # values outside [0, ..., ny[k]].
877
+ # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
878
+ # interpolation table for the dth x value.
879
+ x_idx_unclipped = []
880
+ for k , p_k in enumerate (x_grid_points ):
881
+ # x_k and x_k_clipped shape [A1, ..., An, D].
882
+ # Clip x_k below...no need to clip above since, in the place it is used
883
+ # below, we have a tf.minimum(ny[k] - 1,...)
884
+ x_k = x [..., k ]
885
+ x_k_clipped = tf .maximum (x_k , tf .reduce_min (p_k , axis = - 1 , keepdims = True ))
886
+
887
+ # This construction of indices ensures that idx_below_k < idx_above_k.
888
+ # In particular, the use of x_k_clipped ensures this, even if x_k is OOB.
889
+ idx_above_k = tf .minimum (
890
+ ny [k ] - 1 , tf .searchsorted (p_k , x_k_clipped , side = 'right' ))
891
+ idx_below_k = tf .maximum (idx_above_k - 1 , 0 )
892
+ x_above_k = tf .gather (p_k , idx_above_k , batch_dims = batch_ndims )
893
+ x_below_k = tf .gather (p_k , idx_below_k , batch_dims = batch_ndims )
894
+
895
+ # The use of x_k (not clipped) here allows x_idx_unclipped to be < 0 or >
896
+ # ny[k] - 1.
897
+ x_idx_unclipped .append (
898
+ tf .cast (idx_below_k , dtype ) + (x_k - x_below_k ) /
899
+ (x_above_k - x_below_k ))
900
+
901
+ x_idx_unclipped = tf .stack (x_idx_unclipped , axis = - 1 )
902
+
903
+ return _batch_interp_with_gather_nd (
904
+ x = x ,
905
+ x_idx_unclipped = x_idx_unclipped ,
906
+ y_ref = y_ref ,
907
+ nd = nd ,
908
+ fill_value = fill_value ,
909
+ batch_ndims = batch_ndims )
910
+
911
+
701
912
def _batch_interp_with_gather_nd (x , x_idx_unclipped , y_ref , nd , fill_value ,
702
913
batch_ndims ):
703
914
"""Batch interpolation starting with indices."""
@@ -852,6 +1063,38 @@ def _assert_ndims_statically(x,
852
1063
expect_ndims_at_least , ndims ))
853
1064
854
1065
1066
+ def _intake_fill_value_for_nd_interp (fill_value , dtype ):
1067
+ """Check `fill_value` and return after converting numeric to tensor."""
1068
+ if isinstance (fill_value , str ):
1069
+ if fill_value != 'constant_extension' :
1070
+ raise ValueError (
1071
+ 'A fill value ({}) was not an allowed string ({})' .format (
1072
+ fill_value , 'constant_extension' ))
1073
+ else :
1074
+ fill_value = tf .convert_to_tensor (
1075
+ fill_value , name = 'fill_value' , dtype = dtype )
1076
+ _assert_ndims_statically (fill_value , expect_ndims = 0 )
1077
+ return fill_value
1078
+
1079
+
1080
+ def _intake_axis_for_nd_interp (axis , y_ref , nd ):
1081
+ """Convert `axis` to its non-negative value and return after validation."""
1082
+ axis = ps .convert_to_shape_tensor (axis , dtype = tf .int32 , name = 'axis' )
1083
+ axis = ps .non_negative_axis (axis , ps .rank (y_ref ))
1084
+ tensorshape_util .assert_has_rank (axis .shape , 0 )
1085
+ axis_ = tf .get_static_value (axis )
1086
+ y_ref_rank_ = tf .get_static_value (tf .rank (y_ref ))
1087
+ if axis_ is not None and y_ref_rank_ is not None :
1088
+ if axis_ + nd > y_ref_rank_ :
1089
+ raise ValueError (
1090
+ 'Since dims `[axis, axis + nd)` index the interpolation table, we '
1091
+ 'must have `axis + nd <= rank(y_ref)`. Found: '
1092
+ '`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
1093
+ 'dimensions of `x_ref_min` to be {}.' .format (
1094
+ axis_ , y_ref_rank_ , nd ))
1095
+ return axis
1096
+
1097
+
855
1098
def _make_expand_x_fn_for_non_batch_interpolation (y_ref , axis ):
856
1099
"""Make func to expand left/right (of axis) dims of tensors shaped like x."""
857
1100
# This expansion is to help x broadcast with `y`, the output.
0 commit comments