Skip to content

Commit 7c617db

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Deprecate get_fitted_map_saas_ensemble in favor of EnsembleMapSaasGP (meta-pytorch#3036)
Summary: With `EnsembleMapSaasGP`, we no longer need a helper that constructs a `SaasFullyBayasianSingleTaskGP` from individually fitted models. Differential Revision: D83782823
1 parent e390e13 commit 7c617db

File tree

2 files changed

+37
-141
lines changed

2 files changed

+37
-141
lines changed

botorch/fit.py

Lines changed: 17 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from typing import Any
1616
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
1717

18-
import torch
19-
2018
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
2119
from botorch.exceptions.warnings import OptimizationWarning
2220
from botorch.logging import logger
@@ -27,7 +25,7 @@
2725
SaasFullyBayesianSingleTaskGP,
2826
)
2927
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
30-
from botorch.models.map_saas import get_map_saas_model
28+
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP, get_map_saas_model
3129
from botorch.models.model_list_gp_regression import ModelListGP
3230
from botorch.models.transforms.input import InputTransform
3331
from botorch.models.transforms.outcome import OutcomeTransform
@@ -45,6 +43,7 @@
4543
TensorCheckpoint,
4644
)
4745
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
46+
from botorch.utils.types import _DefaultType, DEFAULT
4847
from gpytorch.likelihoods import Likelihood
4948
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
5049
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
@@ -53,7 +52,6 @@
5352
from linear_operator.utils.errors import NotPSDError
5453
from pyro.infer.mcmc import MCMC, NUTS
5554
from torch import device, Tensor
56-
from torch.distributions import HalfCauchy
5755
from torch.nn import Parameter
5856
from torch.utils.data import DataLoader
5957

@@ -443,13 +441,15 @@ def get_fitted_map_saas_ensemble(
443441
train_Y: Tensor,
444442
train_Yvar: Tensor | None = None,
445443
input_transform: InputTransform | None = None,
446-
outcome_transform: OutcomeTransform | None = None,
444+
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
447445
taus: Tensor | list[float] | None = None,
448446
num_taus: int = 4,
449447
optimizer_kwargs: dict[str, Any] | None = None,
450448
) -> SaasFullyBayesianSingleTaskGP:
451449
"""Get a fitted SAAS ensemble using several different tau values.
452450
451+
DEPRECATED: Please use `EnsembleMapSaasGP` directly!
452+
453453
Args:
454454
train_X: Tensor of shape `n x d` with training inputs.
455455
train_Y: Tensor of shape `n x 1` with training targets.
@@ -464,57 +464,21 @@ def get_fitted_map_saas_ensemble(
464464
to fit_gpytorch_mll.
465465
466466
Returns:
467-
A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
467+
A fitted EnsembleMapSaasGP with a Matern kernel.
468468
"""
469-
tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
470-
if taus is None:
471-
taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
472-
num_samples = len(taus)
473-
if num_samples == 1:
474-
raise ValueError(
475-
"Use `get_fitted_map_saas_model` if you only specify one value of tau"
476-
)
477-
478-
mean = torch.zeros(num_samples, **tkwargs)
479-
outputscale = torch.zeros(num_samples, **tkwargs)
480-
lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
481-
noise = torch.zeros(num_samples, **tkwargs)
482-
483-
# Fit a model for each tau and save the hyperparameters
484-
for i, tau in enumerate(taus):
485-
model = get_fitted_map_saas_model(
486-
train_X,
487-
train_Y,
488-
train_Yvar=train_Yvar,
489-
input_transform=input_transform,
490-
outcome_transform=outcome_transform,
491-
tau=tau,
492-
optimizer_kwargs=optimizer_kwargs,
493-
)
494-
mean[i] = model.mean_module.constant.detach().clone()
495-
outputscale[i] = model.covar_module.outputscale.detach().clone()
496-
lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
497-
if train_Yvar is None:
498-
noise[i] = model.likelihood.noise.detach().clone()
499-
500-
# Load the samples into a fully Bayesian SAAS model
501-
ensemble_model = SaasFullyBayesianSingleTaskGP(
469+
logger.warning(
470+
"get_fitted_map_saas_ensemble is deprecated and will be removed in v0.17. "
471+
"Please use EnsembleMapSaasGP instead!"
472+
)
473+
model = EnsembleMapSaasSingleTaskGP(
502474
train_X=train_X,
503475
train_Y=train_Y,
504476
train_Yvar=train_Yvar,
505-
input_transform=(
506-
input_transform.train() if input_transform is not None else None
507-
),
477+
num_taus=num_taus,
478+
taus=taus,
479+
input_transform=input_transform,
508480
outcome_transform=outcome_transform,
509481
)
510-
mcmc_samples = {
511-
"mean": mean,
512-
"outputscale": outputscale,
513-
"lengthscale": lengthscale,
514-
}
515-
if train_Yvar is None:
516-
mcmc_samples["noise"] = noise
517-
ensemble_model.train()
518-
ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
519-
ensemble_model.eval()
520-
return ensemble_model
482+
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
483+
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
484+
return model

test/models/test_map_saas.py

Lines changed: 20 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
fit_gpytorch_mll,
1919
get_fitted_map_saas_ensemble,
2020
get_fitted_map_saas_model,
21+
logger,
2122
)
22-
from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP
23+
from botorch.models import SingleTaskGP
2324
from botorch.models.map_saas import (
2425
add_saas_prior,
2526
AdditiveMapSaasSingleTaskGP,
@@ -299,93 +300,24 @@ def test_get_saas_model(self) -> None:
299300
self.assertTrue(loss < loss_short)
300301

301302
def test_get_saas_ensemble(self) -> None:
302-
for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]):
303-
tkwargs = {"device": self.device, "dtype": torch.double}
304-
train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs)
305-
d = train_X.shape[-1]
306-
train_Yvar = (
307-
None
308-
if infer_noise
309-
else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1)
310-
)
311-
# Fit without specifying tau
312-
with torch.random.fork_rng():
313-
torch.manual_seed(0)
314-
model = get_fitted_map_saas_ensemble(
315-
train_X=train_X,
316-
train_Y=train_Y,
317-
train_Yvar=train_Yvar,
318-
input_transform=Normalize(d=d),
319-
outcome_transform=Standardize(m=1),
320-
taus=taus,
321-
)
322-
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
323-
num_taus = 4 if taus is None else len(taus)
324-
self.assertEqual(
325-
model.covar_module.base_kernel.lengthscale.shape,
326-
torch.Size([num_taus, 1, d]),
327-
)
328-
self.assertEqual(model.batch_shape, torch.Size([num_taus]))
329-
# Make sure the lengthscales are reasonable
330-
self.assertGreater(
331-
model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50
332-
)
333-
self.assertLess(
334-
model.covar_module.base_kernel.lengthscale[..., 0].max(), 10
335-
)
336-
337-
# testing optimizer_options: short optimization run with maxiter = 3
338-
with torch.random.fork_rng():
339-
torch.manual_seed(0)
340-
fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll)
341-
with mock.patch(
342-
"botorch.fit.fit_gpytorch_mll",
343-
new=fit_gpytorch_mll_mock,
344-
):
345-
maxiter = 3
346-
model_short = get_fitted_map_saas_ensemble(
347-
train_X=train_X,
348-
train_Y=train_Y,
349-
train_Yvar=train_Yvar,
350-
input_transform=Normalize(d=d),
351-
outcome_transform=Standardize(m=1),
352-
taus=taus,
353-
optimizer_kwargs={"options": {"maxiter": maxiter}},
354-
)
355-
kwargs = fit_gpytorch_mll_mock.call_args.kwargs
356-
# fit_gpytorch_mll has "option" kwarg, not "optimizer_options"
357-
self.assertEqual(
358-
kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter
359-
)
360-
361-
# compute sum of marginal likelihoods of ensemble after short run
362-
# NOTE: We can't put MLL in train mode here since
363-
# SaasFullyBayesianSingleTaskGP requires NUTS for training.
364-
mll_short = ExactMarginalLogLikelihood(
365-
model=model_short, likelihood=model_short.likelihood
303+
train_X, train_Y, _ = self._get_data_hardcoded(device=self.device)
304+
with self.assertLogs(logger=logger, level="WARNING") as logs, mock.patch(
305+
"botorch.fit.fit_gpytorch_mll"
306+
) as mock_fit:
307+
model = get_fitted_map_saas_ensemble(
308+
train_X=train_X,
309+
train_Y=train_Y,
310+
input_transform=Normalize(d=train_X.shape[-1]),
311+
outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
312+
optimizer_kwargs={"options": {"maxiter": 3}},
366313
)
367-
train_inputs = mll_short.model.train_inputs
368-
train_targets = mll_short.model.train_targets
369-
loss_short = -mll_short(model_short(*train_inputs), train_targets)
370-
# compute sum of marginal likelihoods of ensemble after standard run
371-
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
372-
# reusing train_inputs and train_targets, since the transforms are the same
373-
loss = -mll(model(*train_inputs), train_targets)
374-
# the longer running optimization should have smaller loss than the shorter
375-
self.assertLess((loss - loss_short).max(), 0.0)
376-
377-
# test error message
378-
with self.assertRaisesRegex(
379-
ValueError, "if you only specify one value of tau"
380-
):
381-
model_short = get_fitted_map_saas_ensemble(
382-
train_X=train_X,
383-
train_Y=train_Y,
384-
train_Yvar=train_Yvar,
385-
input_transform=Normalize(d=d),
386-
outcome_transform=Standardize(m=1),
387-
taus=[0.1],
388-
)
314+
self.assertTrue(
315+
any("use EnsembleMapSaasGP instead" in output for output in logs.output)
316+
)
317+
self.assertEqual(
318+
mock_fit.call_args.kwargs["optimizer_kwargs"], {"options": {"maxiter": 3}}
319+
)
320+
self.assertIsInstance(model, EnsembleMapSaasSingleTaskGP)
389321

390322
def test_input_transform_in_train(self) -> None:
391323
train_X, train_Y, test_X = self._get_data()
@@ -522,7 +454,7 @@ def test_batch_model_fitting(self) -> None:
522454

523455
@mock_optimize
524456
def test_emsemble_map_saas(self) -> None:
525-
train_X, train_Y, test_X = self._get_data()
457+
train_X, train_Y, test_X = self._get_data(device=self.device)
526458
d = train_X.shape[-1]
527459
num_taus = 8
528460
for with_options in (False, True):

0 commit comments

Comments
 (0)