Skip to content

Commit d452927

Browse files
csutertensorflower-gardener
authored andcommitted
Make EnsembleAdjustmentKalmanFilterTest tests deterministic.
PiperOrigin-RevId: 381105864
1 parent 14fbf7f commit d452927

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tensorflow_probability/python/experimental/sequential/ensemble_adjustment_kalman_filter_test.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def observation_fn(_, particles, extra):
6464
step=0, particles=particles, extra={'unchanged': 1})
6565

6666
predicted_state = tfs.ensemble_kalman_filter_predict(
67-
state, transition_fn=transition_fn, inflate_fn=None)
67+
state,
68+
transition_fn=transition_fn,
69+
inflate_fn=None,
70+
seed=test_util.test_seed())
6871

6972
# Check that extra is correctly propagated.
7073
self.assertIn('unchanged', predicted_state.extra)
@@ -116,13 +119,18 @@ def observation_fn(_, particles, extra):
116119

117120
for i in range(10):
118121
state = tfs.ensemble_kalman_filter_predict(
119-
state, transition_fn=transition_fn, inflate_fn=None)
122+
state,
123+
transition_fn=transition_fn,
124+
inflate_fn=None,
125+
seed=test_util.test_seed())
120126

121127
self.assertIn('transition_count', state.extra)
122128
self.assertEqual(i + 1, state.extra['transition_count'])
123129

124130
state = tfs.ensemble_adjustment_kalman_filter_update(
125-
state, observation=[1. * i], observation_fn=observation_fn)
131+
state,
132+
observation=[1. * i],
133+
observation_fn=observation_fn)
126134

127135
self.assertIn('observation_count', state.extra)
128136
self.assertEqual(i + 1, state.extra['observation_count'])

0 commit comments

Comments
 (0)