Change dtype of prior parameter of sts.Seasonal from float32 to float64 #2003

@CleverEskimo

Hi all,

I have a question about changing the dtype in sts.Seasonal. I have created several components, and their parameters and priors look like this:

Parameter: local_linear_trend/_slope_scale
Prior: tfp.distributions.LogNormal("slope_scale_prior", batch_shape=[], event_shape=[], dtype=float64)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: month_of_year/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------
Parameter: day_of_week/_drift_scale
Prior: tfp.distributions.LogNormal("LogNormal", batch_shape=[], event_shape=[], dtype=float64)

As you can see, the prior for month_of_year/_drift_scale has dtype float32, while the other priors are float64. This mismatch does not seem to be allowed during training, because it raises an exception:

ValueError: ConstrainedSeasonalStateSpaceModel, type=<dtype: 'float32'>, must be of the same type (<dtype: 'float64'>) as LocalLinearTrendStateSpaceModel.

Is there a way to change the dtype of this prior from float32 to float64?
