Skip to content

Commit 60edbd3

Browse files
brianwa84jburnim
authored andcommitted
Numpy fix for Multinomial sampling.
PiperOrigin-RevId: 346338660
1 parent 50a1ca9 commit 60edbd3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensorflow_probability/python/distributions/multinomial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def fn(i, num_trials, consumed_prob, accum):
413413

414414
num_trials = tf.cast(num_trials, probs.dtype)
415415
# Pre-broadcast with probs
416-
num_trials += tf.zeros_like(probs[..., 0])
416+
num_trials = num_trials + tf.zeros_like(probs[..., 0])
417417
# Pre-enlarge for different output samples
418418
num_trials = _replicate_along_left(num_trials, num_samples)
419419
i = tf.constant(0)

0 commit comments

Comments
 (0)