@@ -64,7 +64,10 @@ def observation_fn(_, particles, extra):
64
64
step = 0 , particles = particles , extra = {'unchanged' : 1 })
65
65
66
66
predicted_state = tfs .ensemble_kalman_filter_predict (
67
- state , transition_fn = transition_fn , inflate_fn = None )
67
+ state ,
68
+ transition_fn = transition_fn ,
69
+ inflate_fn = None ,
70
+ seed = test_util .test_seed ())
68
71
69
72
# Check that extra is correctly propagated.
70
73
self .assertIn ('unchanged' , predicted_state .extra )
@@ -116,13 +119,18 @@ def observation_fn(_, particles, extra):
116
119
117
120
for i in range (10 ):
118
121
state = tfs .ensemble_kalman_filter_predict (
119
- state , transition_fn = transition_fn , inflate_fn = None )
122
+ state ,
123
+ transition_fn = transition_fn ,
124
+ inflate_fn = None ,
125
+ seed = test_util .test_seed ())
120
126
121
127
self .assertIn ('transition_count' , state .extra )
122
128
self .assertEqual (i + 1 , state .extra ['transition_count' ])
123
129
124
130
state = tfs .ensemble_adjustment_kalman_filter_update (
125
- state , observation = [1. * i ], observation_fn = observation_fn )
131
+ state ,
132
+ observation = [1. * i ],
133
+ observation_fn = observation_fn )
126
134
127
135
self .assertIn ('observation_count' , state .extra )
128
136
self .assertEqual (i + 1 , state .extra ['observation_count' ])
0 commit comments