Skip to content

Commit 14db88b

Browse files
srvasude authored and tensorflower-gardener committed
Use dtype_util inside LGSSM to infer dtype.
PiperOrigin-RevId: 472602722
1 parent 24301c5 commit 14db88b

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

tensorflow_probability/python/distributions/linear_gaussian_ssm.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -393,8 +393,17 @@ def __init__(self,
393393
mask, dtype_hint=tf.bool, name='mask')
394394
self._experimental_parallelize = experimental_parallelize
395395

396-
# TODO(b/78475680): Friendly dtype inference.
397-
dtype = initial_state_prior.dtype
396+
dtype_list = [initial_state_prior,
397+
observation_matrix,
398+
transition_matrix,
399+
transition_noise,
400+
observation_noise]
401+
402+
# Infer dtype from time invariant objects. This list will be non-empty
403+
# since it will always include `initial_state_prior`.
404+
dtype = dtype_util.common_dtype(
405+
list(filter(lambda x: not callable(x), dtype_list)),
406+
dtype_hint=tf.float32)
398407

399408
# Internally, the transition and observation matrices are
400409
# canonicalized as callables returning a LinearOperator. This

tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -55,20 +55,23 @@ def _build_iid_normal_model(self, num_timesteps, latent_size,
5555
observation_variance):
5656
"""Build a model whose outputs are IID normal by construction."""
5757

58-
transition_variance = self._build_placeholder(transition_variance)
59-
observation_variance = self._build_placeholder(observation_variance)
58+
transition_variance = self._build_placeholder(
59+
self.dtype(transition_variance))
60+
observation_variance = self._build_placeholder(
61+
self.dtype(observation_variance))
6062

6163
# Use orthogonal matrices to project a (potentially
6264
# high-dimensional) latent space of IID normal variables into a
6365
# low-dimensional observation that is still IID normal.
6466
random_orthogonal_matrix = lambda: np.linalg.qr(
6567
np.random.randn(latent_size, latent_size))[0][:observation_size, :]
66-
observation_matrix = self._build_placeholder(random_orthogonal_matrix())
68+
observation_matrix = self._build_placeholder(
69+
random_orthogonal_matrix().astype(self.dtype))
6770

6871
model = lgssm.LinearGaussianStateSpaceModel(
6972
num_timesteps=num_timesteps,
7073
transition_matrix=self._build_placeholder(
71-
np.zeros([latent_size, latent_size])),
74+
np.zeros([latent_size, latent_size]).astype(self.dtype)),
7275
transition_noise=mvn_diag.MultivariateNormalDiag(
7376
scale_diag=tf.sqrt(transition_variance) *
7477
tf.ones([latent_size], dtype=self.dtype)),
@@ -389,23 +392,27 @@ def testExcessiveConcretizationOfParams(self):
389392
transition_std = 3.0
390393
observation_std = 5.0
391394

395+
dtype = np.float32
396+
392397
num_timesteps = tfp_hps.defer_and_count_usage(
393398
tf.Variable(1, name='num_timesteps'))
394399
transition_matrix = tfp_hps.defer_and_count_usage(
395-
tf.Variable(np.eye(latent_size), name='transition_matrix'))
400+
tf.Variable(
401+
np.eye(latent_size).astype(dtype), name='transition_matrix'))
396402
transition_noise_scale = tfp_hps.defer_and_count_usage(
397403
tf.Variable(
398-
np.full([latent_size], transition_std),
404+
np.full([latent_size], transition_std).astype(dtype),
399405
name='transition_noise_scale'))
400406
observation_matrix = tfp_hps.defer_and_count_usage(
401-
tf.Variable(np.eye(latent_size), name='observation_matrix'))
407+
tf.Variable(
408+
np.eye(latent_size).astype(dtype), name='observation_matrix'))
402409
observation_noise_scale = tfp_hps.defer_and_count_usage(
403410
tf.Variable(
404-
np.full([latent_size], observation_std),
411+
np.full([latent_size], observation_std).astype(dtype),
405412
name='observation_noise_scale'))
406413
initial_state_prior_scale = tfp_hps.defer_and_count_usage(
407414
tf.Variable(
408-
np.full([latent_size], observation_std),
415+
np.full([latent_size], observation_std).astype(dtype),
409416
name='initial_state_prior_scale'))
410417

411418
model = lgssm.LinearGaussianStateSpaceModel(

0 commit comments

Comments
 (0)