@@ -60,6 +60,7 @@ def __init__(self,
60
60
probs = None ,
61
61
validate_args = False ,
62
62
allow_nan_stats = True ,
63
+ require_integer_total_count = True ,
63
64
name = 'NegativeBinomial' ):
64
65
"""Construct NegativeBinomial distributions.
65
66
@@ -89,6 +90,8 @@ def __init__(self,
89
90
(e.g., mean, mode, variance) use the value "`NaN`" to indicate the
90
91
result is undefined. When `False`, an exception is raised if one or
91
92
more of the statistic's batch members are undefined.
93
+ require_integer_total_count: Python `bool`, default `True`. When `True`,
94
+ the total_count parameter is required to be integer.
92
95
name: Python `str` name prefixed to Ops created by this class.
93
96
"""
94
97
@@ -105,6 +108,7 @@ def __init__(self,
105
108
logits , dtype = dtype , name = 'logits' )
106
109
self ._total_count = tensor_util .convert_nonref_to_tensor (
107
110
total_count , dtype = dtype , name = 'total_count' )
111
+ self ._require_integer_total_count = require_integer_total_count
108
112
109
113
super (NegativeBinomial , self ).__init__ (
110
114
dtype = dtype ,
@@ -255,7 +259,7 @@ def _default_event_space_bijector(self):
255
259
def _parameter_control_dependencies (self , is_init ):
256
260
return maybe_assert_negative_binomial_param_correctness (
257
261
is_init , self .validate_args , self ._total_count , self ._probs ,
258
- self ._logits )
262
+ self ._logits , self . _require_integer_total_count )
259
263
260
264
def _sample_control_dependencies (self , x ):
261
265
"""Check counts for proper shape and values, then return tensor version."""
@@ -267,7 +271,8 @@ def _sample_control_dependencies(self, x):
267
271
268
272
269
273
def maybe_assert_negative_binomial_param_correctness (
270
- is_init , validate_args , total_count , probs , logits ):
274
+ is_init , validate_args , total_count , probs , logits ,
275
+ require_integer_total_count ):
271
276
"""Return assertions for `NegativeBinomial`-type distributions."""
272
277
if is_init :
273
278
x , name = (probs , 'probs' ) if logits is None else (logits , 'logits' )
@@ -284,11 +289,13 @@ def maybe_assert_negative_binomial_param_correctness(
284
289
assertions .extend ([
285
290
assert_util .assert_positive (
286
291
total_count ,
287
- message = '`total_count` has components less than or equal to 0.' ),
288
- distribution_util .assert_integer_form (
289
- total_count ,
290
- message = '`total_count` has fractional components.' )
292
+ message = '`total_count` has components less than or equal to 0.' )
291
293
])
294
+ if require_integer_total_count :
295
+ assertions .extend ([
296
+ distribution_util .assert_integer_form (
297
+ total_count , message = '`total_count` has fractional components.' )
298
+ ])
292
299
if probs is not None :
293
300
if is_init != tensor_util .is_ref (probs ):
294
301
probs = tf .convert_to_tensor (probs )
0 commit comments