Skip to content

Commit 4def218

Browse files
committed
go back to precomputing measure-scaling
1 parent 5740ee2 commit 4def218

File tree

5 files changed

+102
-77
lines changed

5 files changed

+102
-77
lines changed

torchcast/exp_smooth/exp_smooth.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,23 @@ def __init__(self,
2222
measure_covariance: Optional[Covariance] = None,
2323
smoothing_matrix: Optional[SmoothingMatrix] = None,
2424
measure_funs: Optional[dict[str, str]] = None,
25-
adaptive_measure_var: bool = False):
25+
adaptive_scaling: bool = False):
2626

2727
super().__init__(
2828
processes=processes,
2929
measures=measures,
3030
measure_covariance=measure_covariance,
3131
measure_funs=measure_funs,
32-
adaptive_measure_var=adaptive_measure_var,
32+
adaptive_scaling=adaptive_scaling,
3333
)
3434
if smoothing_matrix is None:
3535
smoothing_matrix = SmoothingMatrix.from_measures_and_processes(measures=measures, processes=processes)
3636
self.smoothing_matrix = smoothing_matrix.set_id('smoothing_matrix')
3737

3838
def initial_covariance(self, inputs: dict, num_groups: int, num_times: int, _ignore_input: bool = False) -> Tensor:
39-
# initial covariance is always zero. this will be replaced by the 1-step-ahead covariance in the first call to
40-
# predict
41-
ms = self._get_measure_scaling()
42-
return torch.zeros((num_groups, num_times, self.state_rank, self.state_rank), dtype=ms.dtype, device=ms.device)
39+
# initial covariance is always zero. this will be replaced by the 1-step covariance in the first call to predict
40+
m = list(self.processes.values())[0].initial_mean # get a parameter, any parameter, to get device
41+
return torch.zeros((num_groups, num_times, self.state_rank, self.state_rank), dtype=m.dtype, device=m.device)
4342

4443
def _mask_mats(self,
4544
groups: torch.Tensor,
@@ -105,10 +104,14 @@ def _predict_cov(self,
105104
cov: torch.Tensor,
106105
transition_mat: torch.Tensor,
107106
cov1step: torch.Tensor,
107+
scaling: Optional[torch.Tensor] = None,
108108
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
109109
# new_cov will at least be cov1step (see note above in _update_step)
110110
new_cov = cov1step
111111

112+
if scaling is not None:
113+
raise NotImplementedError
114+
112115
# fastpath: if the call to update returned the zero-dim tensor (see _update above) then we are done
113116
if len(cov.shape):
114117
if mask is None or mask.all():

torchcast/kalman_filter/binomial_filter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self,
2222
measure_covariance: Optional[Covariance] = None,
2323
process_covariance: Optional[Covariance] = None,
2424
initial_covariance: Optional[Covariance] = None,
25-
adaptive_measure_var: bool = False):
25+
adaptive_scaling: bool = False):
2626

2727
if binary_measures is None:
2828
binary_measures = list(measures)
@@ -41,7 +41,7 @@ def __init__(self,
4141
process_covariance=process_covariance,
4242
measure_covariance=measure_covariance,
4343
initial_covariance=initial_covariance,
44-
adaptive_measure_var=adaptive_measure_var,
44+
adaptive_scaling=adaptive_scaling,
4545
measure_funs={m: 'ilogit' for m in binary_measures},
4646
)
4747

torchcast/kalman_filter/kalman_filter.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self,
2424
process_covariance: Optional[Covariance] = None,
2525
initial_covariance: Optional[Covariance] = None,
2626
measure_funs: Optional[dict[str, str]] = None,
27-
adaptive_measure_var: bool = False):
27+
adaptive_scaling: bool = False):
2828

2929
if initial_covariance is None:
3030
initial_covariance = Covariance.from_processes(processes, cov_type='initial')
@@ -37,7 +37,7 @@ def __init__(self,
3737
measures=measures,
3838
measure_covariance=measure_covariance,
3939
measure_funs=measure_funs,
40-
adaptive_measure_var=adaptive_measure_var,
40+
adaptive_scaling=adaptive_scaling,
4141
)
4242
self.process_covariance = process_covariance.set_id('process_covariance')
4343
self.initial_covariance = initial_covariance.set_id('initial_covariance')
@@ -46,6 +46,7 @@ def _predict_cov(self,
4646
cov: torch.Tensor,
4747
transition_mat: torch.Tensor,
4848
Q: torch.Tensor,
49+
scaling: Optional[torch.Tensor] = None,
4950
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
5051
if mask is None or mask.all():
5152
mask = slice(None)
@@ -100,20 +101,24 @@ def _parse_kwargs(self,
100101
)
101102

102103
# process-variance:
103-
measure_scaling = torch.diag_embed(self._get_measure_scaling().unsqueeze(0))
104104
pcov_kwargs = {}
105105
if self.process_covariance.expected_kwargs:
106106
pcov_kwargs = {k: kwargs[k] for k in self.process_covariance.expected_kwargs}
107107
used_keys |= set(pcov_kwargs)
108+
109+
mcov = self.measure_covariance({}, num_groups=1, num_times=1, _ignore_input=True)[0, 0]
110+
measure_std = mcov.diagonal(dim1=-2, dim2=-1).sqrt()
111+
for idx in self.measure_covariance.empty_idx:
112+
measure_std[idx] = torch.ones_like(measure_std[idx]) # empty measures have no variance, so set to 1
113+
108114
if pcov_kwargs:
109-
measure_scaling = measure_scaling.unsqueeze(0)
110115
pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=num_timesteps)
111-
Qs = measure_scaling @ pcov_raw @ measure_scaling
116+
Qs = self._apply_cov_scaling(pcov_raw, scaling=measure_std, is_process_cov=True)
112117
predict_kwargs['Q'] = Qs.unbind(1)
113118
else:
114119
# faster if not time-varying
115-
pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=1)
116-
Qs = measure_scaling @ pcov_raw.squeeze(1) @ measure_scaling
120+
pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=1).squeeze(1)
121+
Qs = self._apply_cov_scaling(pcov_raw, scaling=measure_std, is_process_cov=True)
117122
predict_kwargs['Q'] = [Qs] * num_timesteps
118123

119124
return predict_kwargs, update_kwargs, used_keys

torchcast/state_space/adaptive_measure_var.py renamed to torchcast/state_space/adaptive_scaling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
from torchcast.process.utils import Bounded
1313

1414

15-
class AdaptiveMeasureVar(nn.Module):
15+
class AdaptiveScaler(nn.Module):
1616
def reset(self):
1717
raise NotImplementedError
1818

1919
def forward(self, residuals: torch.Tensor, skip_mask: torch.Tensor) -> torch.Tensor:
2020
raise NotImplementedError
2121

2222

23-
class EWMAdaptiveMeasureVar(AdaptiveMeasureVar):
23+
class EWMAdaptiveScaler(AdaptiveScaler):
2424
"""
25-
Exponentially Weighted Moving Average (EWM) based adaptive measure variance.
25+
Exponentially Weighted Moving Average (EWM) based adaptive scaling.
2626
"""
2727

2828
def __init__(self,

0 commit comments

Comments (0)