Commit c7aa3b2

langmore authored and tensorflower-gardener committed
Optimization in batch_interp_*nd_grid functions.
Stop accumulating by appending (possibly large) tensors to a list and then summing. Instead, add the tensors as they become available. This reduces memory usage (if XLA compilation is not used); with XLA compilation, this appears to be a no-op.

PiperOrigin-RevId: 463181092
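The memory tradeoff the commit message describes can be sketched in a toy NumPy model (NumPy stands in for TF tensors here; `nd` and the tensor shape are illustrative values, not taken from the real code):

```python
import numpy as np

# Toy model of the two accumulation strategies. With nd interpolation
# dimensions there are 2**nd corner terms to sum.
nd = 3  # illustrative value
terms = [np.full((4, 4), float(i)) for i in range(2**nd)]

# "List method": all 2**nd (possibly large) summands stay alive until
# the final add_n-style reduction.
y_list = np.add.reduce(terms)

# Running accumulation: only the accumulator plus the current summand
# need to be alive at any point.
y_acc = np.zeros(())  # scalar zero; broadcasts on the first addition
for term in terms:
    y_acc = y_acc + term

assert np.array_equal(y_list, y_acc)
```

Both strategies compute the same sum; the running accumulation simply avoids holding every summand simultaneously when the graph is executed eagerly or without XLA fusion.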
1 parent 4533188 commit c7aa3b2

File tree

1 file changed: +8 -4 lines changed


tensorflow_probability/python/math/interpolation.py

Lines changed: 8 additions & 4 deletions
@@ -975,7 +975,13 @@ def _expand_x_fn(tensor):
     nan_idx = _expand_x_fn(nan_idx)
     t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)
 
-  terms = []
+  # Initialize y and accumulate in a loop. An alternative would be to store
+  # summands in a list. However, without XLA compilation, the "list method"
+  # results in storage of 2^nd (possibly large) summands, which could OOM.
+  # Thus, if you are not XLA compiling, the method below is highly preferred.
+  # With XLA compilation, both methods are equivalent.
+  y = tf.zeros((), dtype=dtype)
+
   # Our work above has located x's fractional index inside a cube of above/below
   # indices. The distance to the below indices is t, and to the above indices
   # is s.
@@ -1027,9 +1033,7 @@ def _expand_x_fn(tensor):
     y_ref_pt = tf.gather_nd(
         y_ref, tf.stack(gather_from_y_ref_idx, axis=-1), batch_dims=batch_ndims)
 
-    terms.append(y_ref_pt * opposite_volume)
-
-  y = tf.math.add_n(terms)
+    y = y + y_ref_pt * opposite_volume
 
   if tf.debugging.is_numeric_tensor(fill_value):
     # Recall x_idx_unclipped.shape = [D, nd],
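The scalar initialization `y = tf.zeros((), dtype=dtype)` relies on broadcasting: a shape-() zero takes on the summand's shape at the first addition, so no full-shape zeros tensor has to be materialized up front. A minimal NumPy analogue (shapes are illustrative):

```python
import numpy as np

# NumPy analogue of initializing the accumulator with a scalar zero.
# The shape-() zero broadcasts against the first summand, so the
# accumulator acquires the summand's shape on the first addition.
y = np.zeros(())
term = np.arange(6.0).reshape(2, 3)  # illustrative summand
y = y + term  # y now has shape (2, 3)
assert y.shape == (2, 3)
```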
