Skip to content

Commit 43e874a

Browse files
committed
Remove unused function
Add test cases
1 parent 0bba698 commit 43e874a

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -925,13 +925,6 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
925925
return x, indices, axis
926926

927927

928-
def _safe_average(totals, counts):
929-
# This tf.where protects `totals` from getting a gradient signal
930-
# when `counts` is 0.
931-
safe_totals = tf.where(~tf.equal(counts, 0), totals, 0)
932-
return tf.where(~tf.equal(counts, 0), safe_totals / counts, 0)
933-
934-
935928
def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,
936929
validate_args=False, name=None):
937930
"""Computes `log(average(to_probs(logits)))` in a numerically stable manner.

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,13 +735,16 @@ def check_windowed(self, func, numpy_func):
735735
check_fn((64, 4, 8), (32, 1, 1), axis=0)
736736
check_fn((64, 4, 8), (32, 4, 1), axis=0)
737737
check_fn((64, 4, 8), (32, 4, 8), axis=0)
738+
check_fn((64, 4, 8), (64, 4, 8), axis=0)
738739
check_fn((64, 4, 8), (64, 64, 1), axis=1)
739740
check_fn((64, 4, 8), (1, 64, 1), axis=1)
740741
check_fn((64, 4, 8), (64, 2, 8), axis=1)
742+
check_fn((64, 4, 8), (64, 4, 8), axis=1)
741743
check_fn((64, 4, 8), (64, 4, 64), axis=2)
742744
check_fn((64, 4, 8), (1, 1, 64), axis=2)
743745
check_fn((64, 4, 8), (64, 4, 4), axis=2)
744746
check_fn((64, 4, 8), (1, 1, 4), axis=2)
747+
check_fn((64, 4, 8), (64, 4, 8), axis=2)
745748

746749
with self.assertRaises(Exception):
747750
# Non broadcastable shapes

0 commit comments

Comments
 (0)