@@ -55,20 +55,23 @@ def _build_iid_normal_model(self, num_timesteps, latent_size,
55
55
observation_variance ):
56
56
"""Build a model whose outputs are IID normal by construction."""
57
57
58
- transition_variance = self ._build_placeholder (transition_variance )
59
- observation_variance = self ._build_placeholder (observation_variance )
58
+ transition_variance = self ._build_placeholder (
59
+ self .dtype (transition_variance ))
60
+ observation_variance = self ._build_placeholder (
61
+ self .dtype (observation_variance ))
60
62
61
63
# Use orthogonal matrices to project a (potentially
62
64
# high-dimensional) latent space of IID normal variables into a
63
65
# low-dimensional observation that is still IID normal.
64
66
random_orthogonal_matrix = lambda : np .linalg .qr (
65
67
np .random .randn (latent_size , latent_size ))[0 ][:observation_size , :]
66
- observation_matrix = self ._build_placeholder (random_orthogonal_matrix ())
68
+ observation_matrix = self ._build_placeholder (
69
+ random_orthogonal_matrix ().astype (self .dtype ))
67
70
68
71
model = lgssm .LinearGaussianStateSpaceModel (
69
72
num_timesteps = num_timesteps ,
70
73
transition_matrix = self ._build_placeholder (
71
- np .zeros ([latent_size , latent_size ])),
74
+ np .zeros ([latent_size , latent_size ]). astype ( self . dtype ) ),
72
75
transition_noise = mvn_diag .MultivariateNormalDiag (
73
76
scale_diag = tf .sqrt (transition_variance ) *
74
77
tf .ones ([latent_size ], dtype = self .dtype )),
@@ -389,23 +392,27 @@ def testExcessiveConcretizationOfParams(self):
389
392
transition_std = 3.0
390
393
observation_std = 5.0
391
394
395
+ dtype = np .float32
396
+
392
397
num_timesteps = tfp_hps .defer_and_count_usage (
393
398
tf .Variable (1 , name = 'num_timesteps' ))
394
399
transition_matrix = tfp_hps .defer_and_count_usage (
395
- tf .Variable (np .eye (latent_size ), name = 'transition_matrix' ))
400
+ tf .Variable (
401
+ np .eye (latent_size ).astype (dtype ), name = 'transition_matrix' ))
396
402
transition_noise_scale = tfp_hps .defer_and_count_usage (
397
403
tf .Variable (
398
- np .full ([latent_size ], transition_std ),
404
+ np .full ([latent_size ], transition_std ). astype ( dtype ) ,
399
405
name = 'transition_noise_scale' ))
400
406
observation_matrix = tfp_hps .defer_and_count_usage (
401
- tf .Variable (np .eye (latent_size ), name = 'observation_matrix' ))
407
+ tf .Variable (
408
+ np .eye (latent_size ).astype (dtype ), name = 'observation_matrix' ))
402
409
observation_noise_scale = tfp_hps .defer_and_count_usage (
403
410
tf .Variable (
404
- np .full ([latent_size ], observation_std ),
411
+ np .full ([latent_size ], observation_std ). astype ( dtype ) ,
405
412
name = 'observation_noise_scale' ))
406
413
initial_state_prior_scale = tfp_hps .defer_and_count_usage (
407
414
tf .Variable (
408
- np .full ([latent_size ], observation_std ),
415
+ np .full ([latent_size ], observation_std ). astype ( dtype ) ,
409
416
name = 'initial_state_prior_scale' ))
410
417
411
418
model = lgssm .LinearGaussianStateSpaceModel (
0 commit comments