Skip to content

Commit d48cdfc

Browse files
committed
Remove unused function
Add test cases
1 parent 2675210 commit d48cdfc

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -925,13 +925,6 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
925925
return x, indices, axis
926926

927927

928-
def _safe_average(totals, counts):
929-
# This tf.where protects `totals` from getting a gradient signal
930-
# when `counts` is 0.
931-
safe_totals = tf.where(~tf.equal(counts, 0), totals, 0)
932-
return tf.where(~tf.equal(counts, 0), safe_totals / counts, 0)
933-
934-
935928
def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,
936929
validate_args=False, name=None):
937930
"""Computes `log(average(to_probs(logits)))` in a numerically stable manner.

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,13 +731,16 @@ def check_windowed(self, func, numpy_func):
731731
check_fn((64, 4, 8), (32, 1, 1), axis=0)
732732
check_fn((64, 4, 8), (32, 4, 1), axis=0)
733733
check_fn((64, 4, 8), (32, 4, 8), axis=0)
734+
check_fn((64, 4, 8), (64, 4, 8), axis=0)
734735
check_fn((64, 4, 8), (64, 64, 1), axis=1)
735736
check_fn((64, 4, 8), (1, 64, 1), axis=1)
736737
check_fn((64, 4, 8), (64, 2, 8), axis=1)
738+
check_fn((64, 4, 8), (64, 4, 8), axis=1)
737739
check_fn((64, 4, 8), (64, 4, 64), axis=2)
738740
check_fn((64, 4, 8), (1, 1, 64), axis=2)
739741
check_fn((64, 4, 8), (64, 4, 4), axis=2)
740742
check_fn((64, 4, 8), (1, 1, 4), axis=2)
743+
check_fn((64, 4, 8), (64, 4, 8), axis=2)
741744

742745
with self.assertRaises(Exception):
743746
# Non broadcastable shapes

0 commit comments

Comments
 (0)