@@ -7,5 +7,5 @@
 from functools import lru_cache
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
 
 import torch
 from botorch.posteriors.gpytorch import GPyTorchPosterior
@@ -79,41 +79,54 @@ def __init__(self, distribution: MultivariateNormal) -> None:
             else distribution.variance.unsqueeze(-1)
         )
 
+        self._mixture_mean: Optional[Tensor] = None
+        self._mixture_variance: Optional[Tensor] = None
+
+        # using lru_cache on methods can cause memory leaks. See flake8 B019
+        # So we define a function here instead, to be called by self.quantile
+        @lru_cache
+        def _quantile(value: Tensor) -> Tensor:
+            r"""Compute the posterior quantiles for the mixture of models."""
+            if value.numel() > 1:
+                return torch.stack([self.quantile(v) for v in value], dim=0)
+            if value <= 0 or value >= 1:
+                raise ValueError("value is expected to be in the range (0, 1).")
+            dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
+            if self.mean.shape[MCMC_DIM] == 1:  # Analytical solution
+                return dist.icdf(value).squeeze(MCMC_DIM)
+            icdf_val = dist.icdf(value)
+            low = icdf_val.min(dim=MCMC_DIM).values - TOL
+            high = icdf_val.max(dim=MCMC_DIM).values + TOL
+            bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
+            return batched_bisect(
+                f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
+                target=value.item(),
+                bounds=bounds,
+            )
+
+        self._quantile = _quantile
+
     @property
-    @lru_cache(maxsize=None)
     def mixture_mean(self) -> Tensor:
         r"""The posterior mean for the mixture of models."""
-        return self._mean.mean(dim=MCMC_DIM)
+        if self._mixture_mean is None:
+            self._mixture_mean = self._mean.mean(dim=MCMC_DIM)
+        return self._mixture_mean
 
     @property
-    @lru_cache(maxsize=None)
     def mixture_variance(self) -> Tensor:
         r"""The posterior variance for the mixture of models."""
-        num_mcmc_samples = self.mean.shape[MCMC_DIM]
-        t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
-        t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
-        t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
-        return t1 + t2 + t3
+        if self._mixture_variance is None:
+            num_mcmc_samples = self.mean.shape[MCMC_DIM]
+            t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
+            t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
+            t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
+            self._mixture_variance = t1 + t2 + t3
+        return self._mixture_variance
 
-    @lru_cache(maxsize=None)
     def quantile(self, value: Tensor) -> Tensor:
         r"""Compute the posterior quantiles for the mixture of models."""
-        if value.numel() > 1:
-            return torch.stack([self.quantile(v) for v in value], dim=0)
-        if value <= 0 or value >= 1:
-            raise ValueError("value is expected to be in the range (0, 1).")
-        dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
-        if self.mean.shape[MCMC_DIM] == 1:  # Analytical solution
-            return dist.icdf(value).squeeze(MCMC_DIM)
-        icdf_val = dist.icdf(value)
-        low = icdf_val.min(dim=MCMC_DIM).values - TOL
-        high = icdf_val.max(dim=MCMC_DIM).values + TOL
-        bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
-        return batched_bisect(
-            f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
-            target=value.item(),
-            bounds=bounds,
-        )
+        return self._quantile(value)
 
     @property
     def batch_range(self) -> Tuple[int, int]:
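For context on the B019 comment added in `__init__`: `functools.lru_cache` applied to a method attaches the cache to the class-level wrapper and keys it on the call arguments, including `self`, so every instance that ever calls the method stays reachable from the cache and is never garbage collected. The sketch below illustrates that leak and the per-instance-closure workaround the diff adopts; the `Leaky`/`NotLeaky` classes are purely illustrative and not part of the BoTorch code.

```python
import gc
import weakref
from functools import lru_cache


class Leaky:
    # The cache lives on the class-level wrapper and stores (self, x) as
    # its key, so it keeps a strong reference to every instance that ever
    # called compute() -- this is what flake8 B019 warns about.
    @lru_cache(maxsize=None)
    def compute(self, x: int) -> int:
        return x * 2


class NotLeaky:
    def __init__(self, factor: int) -> None:
        self.factor = factor

        # The diff's workaround: cache a closure created per instance.
        # The cache is owned by the instance; capturing self only forms a
        # reference cycle, which Python's cyclic GC can reclaim.
        @lru_cache
        def _compute(x: int) -> int:
            return x * self.factor

        self.compute = _compute


a = Leaky()
a.compute(1)
ref_a = weakref.ref(a)
del a
gc.collect()
print(ref_a() is None)  # False: the class-level cache still holds the instance

b = NotLeaky(3)
b.compute(1)
ref_b = weakref.ref(b)
del b
gc.collect()
print(ref_b() is None)  # True: only a per-instance cycle, reclaimed by the GC
```
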
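Aside from the lazy caching, the `mixture_variance` computation is unchanged: `t1 + t2 + t3` is the law of total variance across the MCMC samples, E[Var] + E[mean²] − (E[mean])². Below is a small sanity-check sketch of that identity; the shapes and `MCMC_DIM = 0` are assumptions for illustration only (the module defines its own `MCMC_DIM` constant).

```python
import torch

MCMC_DIM = 0  # assumed for this sketch
num_mcmc_samples, n = 16, 5

mean = torch.randn(num_mcmc_samples, n)      # per-sample posterior means
variance = torch.rand(num_mcmc_samples, n)   # per-sample posterior variances

# The three terms from the diff:
t1 = variance.sum(dim=MCMC_DIM) / num_mcmc_samples        # E[Var]
t2 = mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples     # E[mean^2]
t3 = -(mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)  # -(E[mean])^2
mixture_variance = t1 + t2 + t3

# t2 + t3 is the (population) variance of the per-sample means, so the sum
# matches E[Var] + Var[mean]:
expected = variance.mean(dim=MCMC_DIM) + mean.var(dim=MCMC_DIM, unbiased=False)
assert torch.allclose(mixture_variance, expected)
```
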
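The quantile logic likewise only moves into the cached closure: with a single MCMC sample the Normal `icdf` is exact, otherwise the mixture quantile is found by root-finding on the averaged CDF between the smallest and largest per-component quantiles. Here is a standalone sketch of that idea, using a plain bisection loop in place of BoTorch's `batched_bisect` helper, dropping the batch/MCMC dimensions, and assuming a value for `TOL`.

```python
import torch

TOL = 1e-6  # assumed tolerance for this sketch
q = torch.tensor(0.9)

# Per-MCMC-sample Normal components for a single test point.
mean = torch.randn(16)
std = torch.rand(16) + 0.1
dist = torch.distributions.Normal(loc=mean, scale=std)

# Bracket: the mixture quantile lies between the smallest and largest
# per-component quantiles.
icdf_val = dist.icdf(q)
low = icdf_val.min() - TOL
high = icdf_val.max() + TOL

# Bisect the mixture CDF, i.e. the equal-weight average of component CDFs.
mid = (low + high) / 2
for _ in range(60):
    if dist.cdf(mid).mean() < q:
        low = mid
    else:
        high = mid
    mid = (low + high) / 2

print(dist.cdf(mid).mean())  # ~0.9: mid approximates the mixture's 90th percentile
```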