Skip to content

Commit 7e1350b

Browse files
esantorella authored and facebook-github-bot committed
Don't use functools.lru_cache on methods (meta-pytorch#1650)
Summary: Pull Request resolved: meta-pytorch#1650. It can cause memory leaks. For an explanation and fix, see https://rednafi.github.io/reflections/dont-wrap-instance-methods-with-functoolslru_cache-decorator-in-python.html Reviewed By: Balandat Differential Revision: D42980747 fbshipit-source-id: 225edb7e43c363d1f2b4580bec77488542b1a7c3
1 parent b3d3074 commit 7e1350b

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

botorch/posteriors/fully_bayesian.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
from functools import lru_cache
8-
from typing import Callable, Tuple
8+
from typing import Callable, Optional, Tuple
99

1010
import torch
1111
from botorch.posteriors.gpytorch import GPyTorchPosterior
@@ -79,41 +79,54 @@ def __init__(self, distribution: MultivariateNormal) -> None:
7979
else distribution.variance.unsqueeze(-1)
8080
)
8181

82+
self._mixture_mean: Optional[Tensor] = None
83+
self._mixture_variance: Optional[Tensor] = None
84+
85+
# using lru_cache on methods can cause memory leaks. See flake8 B019
86+
# So we define a function here instead, to be called by self.quantile
87+
@lru_cache
88+
def _quantile(value: Tensor) -> Tensor:
89+
r"""Compute the posterior quantiles for the mixture of models."""
90+
if value.numel() > 1:
91+
return torch.stack([self.quantile(v) for v in value], dim=0)
92+
if value <= 0 or value >= 1:
93+
raise ValueError("value is expected to be in the range (0, 1).")
94+
dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
95+
if self.mean.shape[MCMC_DIM] == 1: # Analytical solution
96+
return dist.icdf(value).squeeze(MCMC_DIM)
97+
icdf_val = dist.icdf(value)
98+
low = icdf_val.min(dim=MCMC_DIM).values - TOL
99+
high = icdf_val.max(dim=MCMC_DIM).values + TOL
100+
bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
101+
return batched_bisect(
102+
f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
103+
target=value.item(),
104+
bounds=bounds,
105+
)
106+
107+
self._quantile = _quantile
108+
82109
@property
83-
@lru_cache(maxsize=None)
84110
def mixture_mean(self) -> Tensor:
85111
r"""The posterior mean for the mixture of models."""
86-
return self._mean.mean(dim=MCMC_DIM)
112+
if self._mixture_mean is None:
113+
self._mixture_mean = self._mean.mean(dim=MCMC_DIM)
114+
return self._mixture_mean
87115

88116
@property
89-
@lru_cache(maxsize=None)
90117
def mixture_variance(self) -> Tensor:
91118
r"""The posterior variance for the mixture of models."""
92-
num_mcmc_samples = self.mean.shape[MCMC_DIM]
93-
t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
94-
t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
95-
t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
96-
return t1 + t2 + t3
119+
if self._mixture_variance is None:
120+
num_mcmc_samples = self.mean.shape[MCMC_DIM]
121+
t1 = self._variance.sum(dim=MCMC_DIM) / num_mcmc_samples
122+
t2 = self._mean.pow(2).sum(dim=MCMC_DIM) / num_mcmc_samples
123+
t3 = -(self._mean.sum(dim=MCMC_DIM) / num_mcmc_samples).pow(2)
124+
self._mixture_variance = t1 + t2 + t3
125+
return self._mixture_variance
97126

98-
@lru_cache(maxsize=None)
99127
def quantile(self, value: Tensor) -> Tensor:
100128
r"""Compute the posterior quantiles for the mixture of models."""
101-
if value.numel() > 1:
102-
return torch.stack([self.quantile(v) for v in value], dim=0)
103-
if value <= 0 or value >= 1:
104-
raise ValueError("value is expected to be in the range (0, 1).")
105-
dist = torch.distributions.Normal(loc=self.mean, scale=self.variance.sqrt())
106-
if self.mean.shape[MCMC_DIM] == 1: # Analytical solution
107-
return dist.icdf(value).squeeze(MCMC_DIM)
108-
icdf_val = dist.icdf(value)
109-
low = icdf_val.min(dim=MCMC_DIM).values - TOL
110-
high = icdf_val.max(dim=MCMC_DIM).values + TOL
111-
bounds = torch.cat((low.unsqueeze(0), high.unsqueeze(0)), dim=0)
112-
return batched_bisect(
113-
f=lambda x: dist.cdf(x.unsqueeze(MCMC_DIM)).mean(dim=MCMC_DIM),
114-
target=value.item(),
115-
bounds=bounds,
116-
)
129+
return self._quantile(value)
117130

118131
@property
119132
def batch_range(self) -> Tuple[int, int]:

0 commit comments

Comments (0)