Skip to content

Commit faba91e

Browse files
davmretensorflower-gardener
authored andcommitted
Support copy and __add__ for StructuralTimeSeries models.
This allows model construction using the sugar: model = tfp.sts.LocalLevel(observed_time_series=series) model += tfp.sts.Seasonal(num_seasons=7, observed_time_series=series) model += tfp.sts.LinearRegression(design_matrix) # etc. This change also adds an `init_parameters` property analogous to the `parameters` property of Distributions and Bijectors. (bare `parameters` was already taken for STS components). This was required in order to support `__add__`. After adding it, it was trivial to also support `copy`. PiperOrigin-RevId: 384768503
1 parent 18c5c9a commit faba91e

File tree

11 files changed

+184
-10
lines changed

11 files changed

+184
-10
lines changed

tensorflow_probability/python/sts/components/autoregressive.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def __init__(self,
324324
name: the name of this model component.
325325
Default value: 'Autoregressive'.
326326
"""
327+
init_parameters = dict(locals())
327328
with tf.name_scope(name or 'Autoregressive') as name:
328329
masked_time_series = None
329330
if observed_time_series is not None:
@@ -386,6 +387,7 @@ def __init__(self,
386387
tfb.Softplus()]))
387388
],
388389
latent_size=order,
390+
init_parameters=init_parameters,
389391
name=name)
390392

391393
@property

tensorflow_probability/python/sts/components/dynamic_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,8 @@ def __init__(self,
269269
Default value: 'DynamicLinearRegression'.
270270
271271
"""
272-
272+
init_parameters = dict(locals())
273273
with tf.name_scope(name or 'DynamicLinearRegression') as name:
274-
275274
dtype = dtype_util.common_dtype(
276275
[design_matrix, drift_scale_prior, initial_weights_prior])
277276

@@ -306,6 +305,7 @@ def __init__(self,
306305
tfb.Softplus()]))
307306
],
308307
latent_size=num_features,
308+
init_parameters=init_parameters,
309309
name=name)
310310

311311
@property

tensorflow_probability/python/sts/components/local_level.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def __init__(self,
281281
name: the name of this model component.
282282
Default value: 'LocalLevel'.
283283
"""
284-
284+
init_parameters = dict(locals())
285285
with tf.name_scope(name or 'LocalLevel') as name:
286286

287287
dtype = dtype_util.common_dtype([level_scale_prior, initial_level_prior])
@@ -319,6 +319,7 @@ def __init__(self,
319319
tfb.Softplus()])),
320320
],
321321
latent_size=1,
322+
init_parameters=init_parameters,
322323
name=name)
323324

324325
@property

tensorflow_probability/python/sts/components/local_linear_trend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def __init__(self,
350350
name: the name of this model component.
351351
Default value: 'LocalLinearTrend'.
352352
"""
353-
353+
init_parameters = dict(locals())
354354
with tf.name_scope(name or 'LocalLinearTrend') as name:
355355
_, observed_stddev, observed_initial = (
356356
sts_util.empirical_statistics(observed_time_series)
@@ -400,6 +400,7 @@ def __init__(self,
400400
Parameter('slope_scale', slope_scale_prior, scaled_softplus)
401401
],
402402
latent_size=2,
403+
init_parameters=init_parameters,
403404
name=name)
404405

405406
@property

tensorflow_probability/python/sts/components/regression.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def __init__(self,
176176
name: the name of this model component.
177177
Default value: 'LinearRegression'.
178178
"""
179+
init_parameters = dict(locals())
179180
with tf.name_scope(name or 'LinearRegression') as name:
180181

181182
if not isinstance(design_matrix, tfl.LinearOperator):
@@ -225,6 +226,7 @@ def __init__(self,
225226
),
226227
],
227228
latent_size=0,
229+
init_parameters=init_parameters,
228230
name=name)
229231

230232
@property
@@ -412,8 +414,8 @@ def __init__(self,
412414
name: the name of this model component.
413415
Default value: 'SparseLinearRegression'.
414416
"""
417+
init_parameters = dict(locals())
415418
with tf.name_scope(name or 'SparseLinearRegression') as name:
416-
417419
if not isinstance(design_matrix, tfl.LinearOperator):
418420
design_matrix = tfl.LinearOperatorFullMatrix(
419421
tf.convert_to_tensor(value=design_matrix, name='design_matrix'),
@@ -468,6 +470,7 @@ def __init__(self,
468470
bijector=tfb.Identity())
469471
],
470472
latent_size=0,
473+
init_parameters=init_parameters,
471474
name=name)
472475

473476
@property

tensorflow_probability/python/sts/components/seasonal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def __init__(self,
812812
name: the name of this model component.
813813
Default value: 'Seasonal'.
814814
"""
815-
815+
init_parameters = dict(locals())
816816
with tf.name_scope(name or 'Seasonal') as name:
817817

818818
_, observed_stddev, observed_initial = (
@@ -875,6 +875,7 @@ def __init__(self,
875875
parameters,
876876
latent_size=(num_seasons - 1
877877
if self.constrain_mean_effect_to_zero else num_seasons),
878+
init_parameters=init_parameters,
878879
name=name)
879880

880881
@property

tensorflow_probability/python/sts/components/semilocal_linear_trend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def __init__(self,
368368
name: the name of this model component.
369369
Default value: 'SemiLocalLinearTrend'.
370370
"""
371-
371+
init_parameters = dict(locals())
372372
with tf.name_scope(name or 'SemiLocalLinearTrend') as name:
373373
if observed_time_series is not None:
374374
_, observed_stddev, observed_initial = sts_util.empirical_statistics(
@@ -429,6 +429,7 @@ def __init__(self,
429429
autoregressive_coef_bijector),
430430
],
431431
latent_size=2,
432+
init_parameters=init_parameters,
432433
name=name)
433434

434435
@property

tensorflow_probability/python/sts/components/smooth_seasonal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def __init__(self,
399399
Default value: 'SmoothSeasonal'.
400400
401401
"""
402-
402+
init_parameters = dict(locals())
403403
with tf.name_scope(name or 'SmoothSeasonal') as name:
404404

405405
_, observed_stddev, observed_initial = (
@@ -436,6 +436,7 @@ def __init__(self,
436436
super(SmoothSeasonal, self).__init__(
437437
parameters=parameters,
438438
latent_size=latent_size,
439+
init_parameters=init_parameters,
439440
name=name)
440441

441442
@property

tensorflow_probability/python/sts/components/sum.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def __init__(self,
421421
Raises:
422422
ValueError: if components do not have unique names.
423423
"""
424-
424+
init_parameters = dict(locals())
425425
with tf.name_scope(name or 'Sum') as name:
426426
if observed_time_series is not None:
427427
observed_mean, observed_stddev, _ = (
@@ -474,6 +474,7 @@ def __init__(self,
474474
parameters=parameters,
475475
latent_size=sum(
476476
[component.latent_size for component in components]),
477+
init_parameters=init_parameters,
477478
name=name)
478479

479480
@property

tensorflow_probability/python/sts/structural_time_series.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ class StructuralTimeSeries(object):
4343
structure determined by the child class.
4444
"""
4545

46-
def __init__(self, parameters, latent_size, name='StructuralTimeSeries'):
46+
def __init__(self, parameters, latent_size, init_parameters=None,
47+
name='StructuralTimeSeries'):
4748
"""Construct a specification for a structural time series model.
4849
4950
Args:
@@ -54,13 +55,27 @@ def __init__(self, parameters, latent_size, name='StructuralTimeSeries'):
5455
parameter ordering used by fitting and inference algorithms.
5556
latent_size: Python `int` specifying the dimensionality of the latent
5657
state space for this model.
58+
init_parameters: Python `dict` of parameters used to instantiate this
59+
model component.
5760
name: Python `str` name for this model component.
5861
"""
5962

63+
self._init_parameters = init_parameters
6064
self._parameters = parameters
6165
self._latent_size = latent_size
6266
self._name = name
6367

68+
@property
69+
def init_parameters(self):
70+
"""Parameters used to instantiate this `StructuralTimeSeries`."""
71+
if self._init_parameters is None:
72+
raise ValueError(
73+
'Component has `init_parameters` of `None`. See the built-in '
74+
'components (e.g., `LocalLevel`) for examples of how to properly '
75+
'store parameters to `__init__`.')
76+
return {k: v for k, v in self._init_parameters.items()
77+
if not k.startswith('__') and v is not self}
78+
6479
@property
6580
def parameters(self):
6681
"""List of Parameter(name, prior, bijector) namedtuples for this model."""
@@ -108,6 +123,66 @@ def batch_shape_tensor(self):
108123
batch_shape, param.prior.batch_shape_tensor())
109124
return batch_shape
110125

126+
def __add__(self, other):
127+
"""Models the sum of the series from the two components."""
128+
# Local import to avoid circular dependency.
129+
from tensorflow_probability.python.sts.components import sum as sts_sum # pylint: disable=g-import-not-at-top
130+
131+
sum_kwargs = {}
132+
if isinstance(self, sts_sum.Sum) and isinstance(other, sts_sum.Sum):
133+
# Debatably, a Sum + Sum should sum the `constant_offset` parameters, and
134+
# should model the sum of the component observation noises (e.g., if
135+
# `noise1 ~ N(0, scale=1)` and `noise2 ~ N(0, scale=1)`, then
136+
# `noise1 + noise2 ~ N(0, scale=sqrt(2))`). But this would be surprising
137+
# in the likely-common case where both sums are modeling the same
138+
# observed_time_series. So we instead treat Sum + Sum as a logical
139+
# 'merge' operation on the components, requiring that all other parameters
140+
# match.
141+
_assert_dict_contents_are_equal(
142+
self.init_parameters, other.init_parameters,
143+
ignore=['components'],
144+
message='Cannot add Sum components with different parameter values.')
145+
sum_kwargs = self.init_parameters
146+
elif isinstance(self, sts_sum.Sum):
147+
sum_kwargs = self.init_parameters
148+
elif isinstance(other, sts_sum.Sum):
149+
sum_kwargs = other.init_parameters
150+
else:
151+
# If creating a Sum from scratch, try to infer a heuristic noise prior.
152+
observed_time_series = self.init_parameters.get(
153+
'observed_time_series', None)
154+
if observed_time_series is None:
155+
observed_time_series = other.init_parameters.get(
156+
'observed_time_series', None)
157+
if observed_time_series is None:
158+
raise ValueError('Could not automatically create a `Sum` component '
159+
'because neither summand was initialized with an '
160+
'`observed_time_series`. You may still instantiate '
161+
'the component manually as `tfp.sts.Sum(...).')
162+
sum_kwargs['observed_time_series'] = observed_time_series
163+
164+
my_components = getattr(self, 'components', [self])
165+
other_components = getattr(other, 'components', [other])
166+
sum_kwargs['components'] = list(my_components) + list(other_components)
167+
return sts_sum.Sum(**sum_kwargs)
168+
169+
def copy(self, **override_parameters_kwargs):
170+
"""Creates a deep copy.
171+
172+
Note: the copy distribution may continue to depend on the original
173+
initialization arguments.
174+
175+
Args:
176+
**override_parameters_kwargs: String/value dictionary of initialization
177+
arguments to override with new values.
178+
Returns:
179+
copy: A new instance of `type(self)` initialized from the union
180+
of self.init_parameters and override_parameters_kwargs, i.e.,
181+
`dict(self.init_parameters, **override_parameters_kwargs)`.
182+
"""
183+
parameters = dict(self.init_parameters, **override_parameters_kwargs)
184+
return type(self)(**parameters)
185+
111186
def _canonicalize_param_vals_as_map(self, param_vals):
112187
"""If given an ordered list of parameter values, build a name:value map.
113188
@@ -278,3 +353,22 @@ def log_joint_fn(*param_vals, **param_kwargs):
278353
return param_lp + observation_lp
279354

280355
return log_joint_fn
356+
357+
358+
def _strict_equals(a, b):
359+
try:
360+
return bool(a == b)
361+
except ValueError:
362+
# `a == b` is an array, Tensor, or similar.
363+
return id(a) == id(b)
364+
365+
366+
def _assert_dict_contents_are_equal(
367+
a, b, message, ignore=(), equals_fn=_strict_equals):
368+
combined_keys = set(a.keys()) | set(b.keys())
369+
for k in combined_keys - set(ignore):
370+
a_val = a.get(k, None)
371+
b_val = b.get(k, None)
372+
if not _strict_equals(a_val, b_val):
373+
raise ValueError(message +
374+
' `{}`: {} vs {}.'.format(k, a_val, b_val))

0 commit comments

Comments
 (0)