Skip to content

Commit 6a82ed8

Browse files
Googlertensorflower-gardener
authored andcommitted
Add require_integer_total_count optional construction argument to
NegativeBinomial. PiperOrigin-RevId: 462704047
1 parent 4153e36 commit 6a82ed8

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

tensorflow_probability/python/distributions/negative_binomial.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(self,
6060
probs=None,
6161
validate_args=False,
6262
allow_nan_stats=True,
63+
require_integer_total_count=True,
6364
name='NegativeBinomial'):
6465
"""Construct NegativeBinomial distributions.
6566
@@ -89,6 +90,8 @@ def __init__(self,
8990
(e.g., mean, mode, variance) use the value "`NaN`" to indicate the
9091
result is undefined. When `False`, an exception is raised if one or
9192
more of the statistic's batch members are undefined.
93+
require_integer_total_count: Python `bool`, default `True`. When `True`,
94+
the total_count parameter is required to be integer.
9295
name: Python `str` name prefixed to Ops created by this class.
9396
"""
9497

@@ -105,6 +108,7 @@ def __init__(self,
105108
logits, dtype=dtype, name='logits')
106109
self._total_count = tensor_util.convert_nonref_to_tensor(
107110
total_count, dtype=dtype, name='total_count')
111+
self._require_integer_total_count = require_integer_total_count
108112

109113
super(NegativeBinomial, self).__init__(
110114
dtype=dtype,
@@ -255,7 +259,7 @@ def _default_event_space_bijector(self):
255259
def _parameter_control_dependencies(self, is_init):
256260
return maybe_assert_negative_binomial_param_correctness(
257261
is_init, self.validate_args, self._total_count, self._probs,
258-
self._logits)
262+
self._logits, self._require_integer_total_count)
259263

260264
def _sample_control_dependencies(self, x):
261265
"""Check counts for proper shape and values, then return tensor version."""
@@ -267,7 +271,8 @@ def _sample_control_dependencies(self, x):
267271

268272

269273
def maybe_assert_negative_binomial_param_correctness(
270-
is_init, validate_args, total_count, probs, logits):
274+
is_init, validate_args, total_count, probs, logits,
275+
require_integer_total_count):
271276
"""Return assertions for `NegativeBinomial`-type distributions."""
272277
if is_init:
273278
x, name = (probs, 'probs') if logits is None else (logits, 'logits')
@@ -284,11 +289,13 @@ def maybe_assert_negative_binomial_param_correctness(
284289
assertions.extend([
285290
assert_util.assert_positive(
286291
total_count,
287-
message='`total_count` has components less than or equal to 0.'),
288-
distribution_util.assert_integer_form(
289-
total_count,
290-
message='`total_count` has fractional components.')
292+
message='`total_count` has components less than or equal to 0.')
291293
])
294+
if require_integer_total_count:
295+
assertions.extend([
296+
distribution_util.assert_integer_form(
297+
total_count, message='`total_count` has fractional components.')
298+
])
292299
if probs is not None:
293300
if is_init != tensor_util.is_ref(probs):
294301
probs = tf.convert_to_tensor(probs)

tensorflow_probability/python/distributions/negative_binomial_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,16 @@ def testGradientOfLogProbEvaluates(self):
318318
self.evaluate(tfp.math.value_and_gradient(
319319
tfd.NegativeBinomial(0.1, 0.).log_prob, [0.1]))
320320

321+
def testRequireIntegerTotalCount(self):
322+
with self.assertRaisesOpError(
323+
'`total_count` has fractional components.'):
324+
d = tfd.NegativeBinomial(total_count=2.5, probs=0.7, validate_args=True)
325+
self.evaluate(d.log_prob(5))
326+
327+
d2 = tfd.NegativeBinomial(total_count=2.5, probs=0.7, validate_args=True,
328+
require_integer_total_count=False)
329+
self.evaluate(d2.log_prob(5))
330+
321331

322332
@test_util.test_all_tf_execution_regimes
323333
class NegativeBinomialFromVariableTest(test_util.TestCase):

0 commit comments

Comments
 (0)