@@ -24,7 +24,7 @@ def __init__(self,
2424 process_covariance : Optional [Covariance ] = None ,
2525 initial_covariance : Optional [Covariance ] = None ,
2626 measure_funs : Optional [dict [str , str ]] = None ,
27- adaptive_measure_var : bool = False ):
27+ adaptive_scaling : bool = False ):
2828
2929 if initial_covariance is None :
3030 initial_covariance = Covariance .from_processes (processes , cov_type = 'initial' )
@@ -37,7 +37,7 @@ def __init__(self,
3737 measures = measures ,
3838 measure_covariance = measure_covariance ,
3939 measure_funs = measure_funs ,
40- adaptive_measure_var = adaptive_measure_var ,
40+ adaptive_scaling = adaptive_scaling ,
4141 )
4242 self .process_covariance = process_covariance .set_id ('process_covariance' )
4343 self .initial_covariance = initial_covariance .set_id ('initial_covariance' )
@@ -46,6 +46,7 @@ def _predict_cov(self,
4646 cov : torch .Tensor ,
4747 transition_mat : torch .Tensor ,
4848 Q : torch .Tensor ,
49+ scaling : Optional [torch .Tensor ] = None ,
4950 mask : Optional [torch .Tensor ] = None ) -> torch .Tensor :
5051 if mask is None or mask .all ():
5152 mask = slice (None )
@@ -100,20 +101,24 @@ def _parse_kwargs(self,
100101 )
101102
102103 # process-variance:
103- measure_scaling = torch .diag_embed (self ._get_measure_scaling ().unsqueeze (0 ))
104104 pcov_kwargs = {}
105105 if self .process_covariance .expected_kwargs :
106106 pcov_kwargs = {k : kwargs [k ] for k in self .process_covariance .expected_kwargs }
107107 used_keys |= set (pcov_kwargs )
108+
109+ mcov = self .measure_covariance ({}, num_groups = 1 , num_times = 1 , _ignore_input = True )[0 , 0 ]
110+ measure_std = mcov .diagonal (dim1 = - 2 , dim2 = - 1 ).sqrt ()
111+ for idx in self .measure_covariance .empty_idx :
112+ measure_std [idx ] = torch .ones_like (measure_std [idx ]) # empty measures have no variance, so set to 1
113+
108114 if pcov_kwargs :
109- measure_scaling = measure_scaling .unsqueeze (0 )
110115 pcov_raw = self .process_covariance (pcov_kwargs , num_groups = num_groups , num_times = num_timesteps )
111- Qs = measure_scaling @ pcov_raw @ measure_scaling
116+ Qs = self . _apply_cov_scaling ( pcov_raw , scaling = measure_std , is_process_cov = True )
112117 predict_kwargs ['Q' ] = Qs .unbind (1 )
113118 else :
114119 # faster if not time-varying
115- pcov_raw = self .process_covariance (pcov_kwargs , num_groups = num_groups , num_times = 1 )
116- Qs = measure_scaling @ pcov_raw . squeeze ( 1 ) @ measure_scaling
120+ pcov_raw = self .process_covariance (pcov_kwargs , num_groups = num_groups , num_times = 1 ). squeeze ( 1 )
121+ Qs = self . _apply_cov_scaling ( pcov_raw , scaling = measure_std , is_process_cov = True )
117122 predict_kwargs ['Q' ] = [Qs ] * num_timesteps
118123
119124 return predict_kwargs , update_kwargs , used_keys
0 commit comments