@@ -307,6 +307,11 @@ def test_prior_sample(self):
307307 2 ,
308308 ] + param .prior .batch_shape .as_list () + param .prior .event_shape .as_list ())
309309
310+ def test_joint_distribution_name (self ):
311+ model = self ._build_sts (name = 'foo' )
312+ jd = model .joint_distribution (num_timesteps = 5 )
313+ self .assertEqual ('foo' , jd .name )
314+
310315 def test_joint_distribution_log_prob (self ):
311316 model = self ._build_sts (
312317 # Dummy series to build the model with float64 priors. Working in
@@ -402,14 +407,17 @@ def test_add_component(self):
402407@test_util .test_all_tf_execution_regimes
403408class AutoregressiveTest (test_util .TestCase , _StsTestHarness ):
404409
405- def _build_sts (self , observed_time_series = None ):
406- return Autoregressive (order = 3 , observed_time_series = observed_time_series )
410+ def _build_sts (self , observed_time_series = None , ** kwargs ):
411+ return Autoregressive (
412+ order = 3 ,
413+ observed_time_series = observed_time_series ,
414+ ** kwargs )
407415
408416
409417@test_util .test_all_tf_execution_regimes
410418class ARMATest (test_util .TestCase , _StsTestHarness ):
411419
412- def _build_sts (self , observed_time_series = None ):
420+ def _build_sts (self , observed_time_series = None , ** kwargs ):
413421 one = 1.
414422 if observed_time_series is not None :
415423 observed_time_series = (
@@ -421,36 +429,42 @@ def _build_sts(self, observed_time_series=None):
421429 ma_order = 1 ,
422430 integration_degree = 0 ,
423431 level_drift_prior = tfd .Normal (loc = one , scale = one ),
424- observed_time_series = observed_time_series )
432+ observed_time_series = observed_time_series ,
433+ ** kwargs )
425434
426435
427436@test_util .test_all_tf_execution_regimes
428437class ARIMATest (test_util .TestCase , _StsTestHarness ):
429438
430- def _build_sts (self , observed_time_series = None ):
439+ def _build_sts (self , observed_time_series = None , ** kwargs ):
431440 return AutoregressiveIntegratedMovingAverage (
432441 ar_order = 1 , ma_order = 2 , integration_degree = 2 ,
433- observed_time_series = observed_time_series )
442+ observed_time_series = observed_time_series ,
443+ ** kwargs )
434444
435445
436446@test_util .test_all_tf_execution_regimes
437447class LocalLevelTest (test_util .TestCase , _StsTestHarness ):
438448
439- def _build_sts (self , observed_time_series = None ):
440- return LocalLevel (observed_time_series = observed_time_series )
449+ def _build_sts (self , observed_time_series = None , ** kwargs ):
450+ return LocalLevel (
451+ observed_time_series = observed_time_series ,
452+ ** kwargs )
441453
442454
443455@test_util .test_all_tf_execution_regimes
444456class LocalLinearTrendTest (test_util .TestCase , _StsTestHarness ):
445457
446- def _build_sts (self , observed_time_series = None ):
447- return LocalLinearTrend (observed_time_series = observed_time_series )
458+ def _build_sts (self , observed_time_series = None , ** kwargs ):
459+ return LocalLinearTrend (
460+ observed_time_series = observed_time_series ,
461+ ** kwargs )
448462
449463
450464@test_util .test_all_tf_execution_regimes
451465class SeasonalTest (test_util .TestCase , _StsTestHarness ):
452466
453- def _build_sts (self , observed_time_series = None ):
467+ def _build_sts (self , observed_time_series = None , ** kwargs ):
454468 # Note that a Seasonal model with `num_steps_per_season > 1` would have
455469 # deterministic dependence between timesteps, so evaluating `log_prob` of an
456470 # arbitrary time series leads to Cholesky decomposition errors unless the
@@ -461,79 +475,87 @@ def _build_sts(self, observed_time_series=None):
461475 return Seasonal (num_seasons = 7 ,
462476 num_steps_per_season = 1 ,
463477 observed_time_series = observed_time_series ,
464- constrain_mean_effect_to_zero = False )
478+ constrain_mean_effect_to_zero = False ,
479+ ** kwargs )
465480
466481
467482@test_util .test_all_tf_execution_regimes
468483class SeasonalWithZeroMeanConstraintTest (test_util .TestCase , _StsTestHarness ):
469484
470- def _build_sts (self , observed_time_series = None ):
485+ def _build_sts (self , observed_time_series = None , ** kwargs ):
471486 return Seasonal (num_seasons = 7 ,
472487 num_steps_per_season = 1 ,
473488 observed_time_series = observed_time_series ,
474- constrain_mean_effect_to_zero = True )
489+ constrain_mean_effect_to_zero = True ,
490+ ** kwargs )
475491
476492
477493@test_util .test_all_tf_execution_regimes
478494class SeasonalWithMultipleStepsAndNoiseTest (test_util .TestCase ,
479495 _StsTestHarness ):
480496
481- def _build_sts (self , observed_time_series = None ):
497+ def _build_sts (self , observed_time_series = None , ** kwargs ):
482498 day_of_week = tfp .sts .Seasonal (num_seasons = 7 ,
483499 num_steps_per_season = 24 ,
484500 allow_drift = False ,
485501 observed_time_series = observed_time_series ,
486502 name = 'day_of_week' )
487503 return tfp .sts .Sum (components = [day_of_week ],
488- observed_time_series = observed_time_series )
504+ observed_time_series = observed_time_series ,
505+ ** kwargs )
489506
490507
491508@test_util .test_all_tf_execution_regimes
492509class SemiLocalLinearTrendTest (test_util .TestCase , _StsTestHarness ):
493510
494- def _build_sts (self , observed_time_series = None ):
495- return SemiLocalLinearTrend (observed_time_series = observed_time_series )
511+ def _build_sts (self , observed_time_series = None , ** kwargs ):
512+ return SemiLocalLinearTrend (
513+ observed_time_series = observed_time_series ,
514+ ** kwargs )
496515
497516
498517@test_util .test_all_tf_execution_regimes
499518class SmoothSeasonalTest (test_util .TestCase , _StsTestHarness ):
500519
501- def _build_sts (self , observed_time_series = None ):
520+ def _build_sts (self , observed_time_series = None , ** kwargs ):
502521 return SmoothSeasonal (period = 42 ,
503522 frequency_multipliers = [1 , 2 , 4 ],
504- observed_time_series = observed_time_series )
523+ observed_time_series = observed_time_series ,
524+ ** kwargs )
505525
506526
507527@test_util .test_all_tf_execution_regimes
508528class SmoothSeasonalWithNoDriftTest (test_util .TestCase , _StsTestHarness ):
509529
510- def _build_sts (self , observed_time_series = None ):
530+ def _build_sts (self , observed_time_series = None , ** kwargs ):
511531 smooth_seasonal = SmoothSeasonal (period = 42 ,
512532 frequency_multipliers = [1 , 2 , 4 ],
513533 allow_drift = False ,
514534 observed_time_series = observed_time_series )
515535 # The test harness doesn't like models with no parameters, so wrap with Sum.
516536 return tfp .sts .Sum ([smooth_seasonal ],
517- observed_time_series = observed_time_series )
537+ observed_time_series = observed_time_series ,
538+ ** kwargs )
518539
519540
520541@test_util .test_all_tf_execution_regimes
521542class SumTest (test_util .TestCase , _StsTestHarness ):
522543
523- def _build_sts (self , observed_time_series = None ):
544+ def _build_sts (self , observed_time_series = None , ** kwargs ):
524545 first_component = LocalLinearTrend (
525546 observed_time_series = observed_time_series , name = 'first_component' )
526547 second_component = LocalLinearTrend (
527548 observed_time_series = observed_time_series , name = 'second_component' )
528549 return Sum (
529550 components = [first_component , second_component ],
530- observed_time_series = observed_time_series )
551+ observed_time_series = observed_time_series ,
552+ ** kwargs )
531553
532554
533555@test_util .test_all_tf_execution_regimes
534556class LinearRegressionTest (test_util .TestCase , _StsTestHarness ):
535557
536- def _build_sts (self , observed_time_series = None ):
558+ def _build_sts (self , observed_time_series = None , ** kwargs ):
537559 max_timesteps = 100
538560 num_features = 3
539561
@@ -557,13 +579,14 @@ def _build_sts(self, observed_time_series=None):
557579 max_timesteps , num_features ).astype (dtype ),
558580 weights_prior = prior )
559581 return Sum (components = [regression ],
560- observed_time_series = observed_time_series )
582+ observed_time_series = observed_time_series ,
583+ ** kwargs )
561584
562585
563586@test_util .test_all_tf_execution_regimes
564587class SparseLinearRegressionTest (test_util .TestCase , _StsTestHarness ):
565588
566- def _build_sts (self , observed_time_series = None ):
589+ def _build_sts (self , observed_time_series = None , ** kwargs ):
567590 max_timesteps = 100
568591 num_features = 3
569592
@@ -584,13 +607,14 @@ def _build_sts(self, observed_time_series=None):
584607 max_timesteps , num_features ).astype (dtype ),
585608 weights_batch_shape = batch_shape )
586609 return Sum (components = [regression ],
587- observed_time_series = observed_time_series )
610+ observed_time_series = observed_time_series ,
611+ ** kwargs )
588612
589613
590614@test_util .test_all_tf_execution_regimes
591615class DynamicLinearRegressionTest (test_util .TestCase , _StsTestHarness ):
592616
593- def _build_sts (self , observed_time_series = None ):
617+ def _build_sts (self , observed_time_series = None , ** kwargs ):
594618 max_timesteps = 100
595619 num_features = 3
596620
@@ -604,7 +628,8 @@ def _build_sts(self, observed_time_series=None):
604628 return DynamicLinearRegression (
605629 design_matrix = np .random .randn (
606630 max_timesteps , num_features ).astype (dtype ),
607- observed_time_series = observed_time_series )
631+ observed_time_series = observed_time_series ,
632+ ** kwargs )
608633
609634
610635if __name__ == '__main__' :
0 commit comments