Hi all,
I have a question about a dtype change in sts.Seasonal. I created several components, and their parameters report priors like these:
Parameter: local_linear_trend/_slope_scale
Prior: tfp.distributions.LogNormal("slope_scale_prior", batch_shape=[], event_shape=[], dtype=float64)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: month_of_year/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: day_of_week/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float64)
As you can see, the prior for month_of_year/_drift_scale has dtype float32, while the other two are float64. This mismatch does not seem to be allowed during training, because it raises an exception:
ValueError: ConstrainedSeasonalStateSpaceModel, type=<dtype: 'float32'>, must be of the same type (<dtype: 'float64'>) as LocalLinearTrendStateSpaceModel.
Is there a way to change the dtype of that prior from float32 to float64?