Skip to content

Commit 6e83921

Browse files
sharadmvtensorflower-gardener
authored andcommitted
Thread STS name into its joint distribution
PiperOrigin-RevId: 427855341
1 parent afbc1a5 commit 6e83921

File tree

2 files changed

+57
-31
lines changed

2 files changed

+57
-31
lines changed

tensorflow_probability/python/sts/structural_time_series.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,8 @@ def state_space_model_likelihood(**param_vals):
387387
# Likelihood.
388388
[('observed_time_series', state_space_model_likelihood)]),
389389
use_vectorized_map=False,
390-
batch_ndims=batch_ndims))
390+
batch_ndims=batch_ndims,
391+
name=self.name))
391392

392393
if observed_time_series is not None:
393394
return joint_distribution.experimental_pin(

tensorflow_probability/python/sts/structural_time_series_test.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ def test_prior_sample(self):
307307
2,
308308
] + param.prior.batch_shape.as_list() + param.prior.event_shape.as_list())
309309

310+
def test_joint_distribution_name(self):
311+
model = self._build_sts(name='foo')
312+
jd = model.joint_distribution(num_timesteps=5)
313+
self.assertEqual('foo', jd.name)
314+
310315
def test_joint_distribution_log_prob(self):
311316
model = self._build_sts(
312317
# Dummy series to build the model with float64 priors. Working in
@@ -402,14 +407,17 @@ def test_add_component(self):
402407
@test_util.test_all_tf_execution_regimes
403408
class AutoregressiveTest(test_util.TestCase, _StsTestHarness):
404409

405-
def _build_sts(self, observed_time_series=None):
406-
return Autoregressive(order=3, observed_time_series=observed_time_series)
410+
def _build_sts(self, observed_time_series=None, **kwargs):
411+
return Autoregressive(
412+
order=3,
413+
observed_time_series=observed_time_series,
414+
**kwargs)
407415

408416

409417
@test_util.test_all_tf_execution_regimes
410418
class ARMATest(test_util.TestCase, _StsTestHarness):
411419

412-
def _build_sts(self, observed_time_series=None):
420+
def _build_sts(self, observed_time_series=None, **kwargs):
413421
one = 1.
414422
if observed_time_series is not None:
415423
observed_time_series = (
@@ -421,36 +429,42 @@ def _build_sts(self, observed_time_series=None):
421429
ma_order=1,
422430
integration_degree=0,
423431
level_drift_prior=tfd.Normal(loc=one, scale=one),
424-
observed_time_series=observed_time_series)
432+
observed_time_series=observed_time_series,
433+
**kwargs)
425434

426435

427436
@test_util.test_all_tf_execution_regimes
428437
class ARIMATest(test_util.TestCase, _StsTestHarness):
429438

430-
def _build_sts(self, observed_time_series=None):
439+
def _build_sts(self, observed_time_series=None, **kwargs):
431440
return AutoregressiveIntegratedMovingAverage(
432441
ar_order=1, ma_order=2, integration_degree=2,
433-
observed_time_series=observed_time_series)
442+
observed_time_series=observed_time_series,
443+
**kwargs)
434444

435445

436446
@test_util.test_all_tf_execution_regimes
437447
class LocalLevelTest(test_util.TestCase, _StsTestHarness):
438448

439-
def _build_sts(self, observed_time_series=None):
440-
return LocalLevel(observed_time_series=observed_time_series)
449+
def _build_sts(self, observed_time_series=None, **kwargs):
450+
return LocalLevel(
451+
observed_time_series=observed_time_series,
452+
**kwargs)
441453

442454

443455
@test_util.test_all_tf_execution_regimes
444456
class LocalLinearTrendTest(test_util.TestCase, _StsTestHarness):
445457

446-
def _build_sts(self, observed_time_series=None):
447-
return LocalLinearTrend(observed_time_series=observed_time_series)
458+
def _build_sts(self, observed_time_series=None, **kwargs):
459+
return LocalLinearTrend(
460+
observed_time_series=observed_time_series,
461+
**kwargs)
448462

449463

450464
@test_util.test_all_tf_execution_regimes
451465
class SeasonalTest(test_util.TestCase, _StsTestHarness):
452466

453-
def _build_sts(self, observed_time_series=None):
467+
def _build_sts(self, observed_time_series=None, **kwargs):
454468
# Note that a Seasonal model with `num_steps_per_season > 1` would have
455469
# deterministic dependence between timesteps, so evaluating `log_prob` of an
456470
# arbitrary time series leads to Cholesky decomposition errors unless the
@@ -461,79 +475,87 @@ def _build_sts(self, observed_time_series=None):
461475
return Seasonal(num_seasons=7,
462476
num_steps_per_season=1,
463477
observed_time_series=observed_time_series,
464-
constrain_mean_effect_to_zero=False)
478+
constrain_mean_effect_to_zero=False,
479+
**kwargs)
465480

466481

467482
@test_util.test_all_tf_execution_regimes
468483
class SeasonalWithZeroMeanConstraintTest(test_util.TestCase, _StsTestHarness):
469484

470-
def _build_sts(self, observed_time_series=None):
485+
def _build_sts(self, observed_time_series=None, **kwargs):
471486
return Seasonal(num_seasons=7,
472487
num_steps_per_season=1,
473488
observed_time_series=observed_time_series,
474-
constrain_mean_effect_to_zero=True)
489+
constrain_mean_effect_to_zero=True,
490+
**kwargs)
475491

476492

477493
@test_util.test_all_tf_execution_regimes
478494
class SeasonalWithMultipleStepsAndNoiseTest(test_util.TestCase,
479495
_StsTestHarness):
480496

481-
def _build_sts(self, observed_time_series=None):
497+
def _build_sts(self, observed_time_series=None, **kwargs):
482498
day_of_week = tfp.sts.Seasonal(num_seasons=7,
483499
num_steps_per_season=24,
484500
allow_drift=False,
485501
observed_time_series=observed_time_series,
486502
name='day_of_week')
487503
return tfp.sts.Sum(components=[day_of_week],
488-
observed_time_series=observed_time_series)
504+
observed_time_series=observed_time_series,
505+
**kwargs)
489506

490507

491508
@test_util.test_all_tf_execution_regimes
492509
class SemiLocalLinearTrendTest(test_util.TestCase, _StsTestHarness):
493510

494-
def _build_sts(self, observed_time_series=None):
495-
return SemiLocalLinearTrend(observed_time_series=observed_time_series)
511+
def _build_sts(self, observed_time_series=None, **kwargs):
512+
return SemiLocalLinearTrend(
513+
observed_time_series=observed_time_series,
514+
**kwargs)
496515

497516

498517
@test_util.test_all_tf_execution_regimes
499518
class SmoothSeasonalTest(test_util.TestCase, _StsTestHarness):
500519

501-
def _build_sts(self, observed_time_series=None):
520+
def _build_sts(self, observed_time_series=None, **kwargs):
502521
return SmoothSeasonal(period=42,
503522
frequency_multipliers=[1, 2, 4],
504-
observed_time_series=observed_time_series)
523+
observed_time_series=observed_time_series,
524+
**kwargs)
505525

506526

507527
@test_util.test_all_tf_execution_regimes
508528
class SmoothSeasonalWithNoDriftTest(test_util.TestCase, _StsTestHarness):
509529

510-
def _build_sts(self, observed_time_series=None):
530+
def _build_sts(self, observed_time_series=None, **kwargs):
511531
smooth_seasonal = SmoothSeasonal(period=42,
512532
frequency_multipliers=[1, 2, 4],
513533
allow_drift=False,
514534
observed_time_series=observed_time_series)
515535
# The test harness doesn't like models with no parameters, so wrap with Sum.
516536
return tfp.sts.Sum([smooth_seasonal],
517-
observed_time_series=observed_time_series)
537+
observed_time_series=observed_time_series,
538+
**kwargs)
518539

519540

520541
@test_util.test_all_tf_execution_regimes
521542
class SumTest(test_util.TestCase, _StsTestHarness):
522543

523-
def _build_sts(self, observed_time_series=None):
544+
def _build_sts(self, observed_time_series=None, **kwargs):
524545
first_component = LocalLinearTrend(
525546
observed_time_series=observed_time_series, name='first_component')
526547
second_component = LocalLinearTrend(
527548
observed_time_series=observed_time_series, name='second_component')
528549
return Sum(
529550
components=[first_component, second_component],
530-
observed_time_series=observed_time_series)
551+
observed_time_series=observed_time_series,
552+
**kwargs)
531553

532554

533555
@test_util.test_all_tf_execution_regimes
534556
class LinearRegressionTest(test_util.TestCase, _StsTestHarness):
535557

536-
def _build_sts(self, observed_time_series=None):
558+
def _build_sts(self, observed_time_series=None, **kwargs):
537559
max_timesteps = 100
538560
num_features = 3
539561

@@ -557,13 +579,14 @@ def _build_sts(self, observed_time_series=None):
557579
max_timesteps, num_features).astype(dtype),
558580
weights_prior=prior)
559581
return Sum(components=[regression],
560-
observed_time_series=observed_time_series)
582+
observed_time_series=observed_time_series,
583+
**kwargs)
561584

562585

563586
@test_util.test_all_tf_execution_regimes
564587
class SparseLinearRegressionTest(test_util.TestCase, _StsTestHarness):
565588

566-
def _build_sts(self, observed_time_series=None):
589+
def _build_sts(self, observed_time_series=None, **kwargs):
567590
max_timesteps = 100
568591
num_features = 3
569592

@@ -584,13 +607,14 @@ def _build_sts(self, observed_time_series=None):
584607
max_timesteps, num_features).astype(dtype),
585608
weights_batch_shape=batch_shape)
586609
return Sum(components=[regression],
587-
observed_time_series=observed_time_series)
610+
observed_time_series=observed_time_series,
611+
**kwargs)
588612

589613

590614
@test_util.test_all_tf_execution_regimes
591615
class DynamicLinearRegressionTest(test_util.TestCase, _StsTestHarness):
592616

593-
def _build_sts(self, observed_time_series=None):
617+
def _build_sts(self, observed_time_series=None, **kwargs):
594618
max_timesteps = 100
595619
num_features = 3
596620

@@ -604,7 +628,8 @@ def _build_sts(self, observed_time_series=None):
604628
return DynamicLinearRegression(
605629
design_matrix=np.random.randn(
606630
max_timesteps, num_features).astype(dtype),
607-
observed_time_series=observed_time_series)
631+
observed_time_series=observed_time_series,
632+
**kwargs)
608633

609634

610635
if __name__ == '__main__':

0 commit comments

Comments
 (0)