Skip to content

Commit a87abbd

Browse files
David Erikssonmeta-codesync[bot]
authored andcommitted
EnsembleMapSaasGP -> EnsembleMapSaasSingleTaskGP (meta-pytorch#3038)
Summary: Pull Request resolved: meta-pytorch#3038 Realized a little bit too late that `EnsembleMapSaasGP` will be confusing once this model is extended to an MTGP. Let's rename it consistently with our other models before it's too late :) Reviewed By: Balandat Differential Revision: D83796108 fbshipit-source-id: 9618f63fcec7117af4afa324379a96c8b49790f2
1 parent 0650dcc commit a87abbd

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

botorch/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from botorch.models.map_saas import (
2424
add_saas_prior,
2525
AdditiveMapSaasSingleTaskGP,
26-
EnsembleMapSaasGP,
26+
EnsembleMapSaasSingleTaskGP,
2727
)
2828
from botorch.models.model import ModelList
2929
from botorch.models.model_list_gp_regression import ModelListGP
@@ -36,9 +36,7 @@
3636
"AffineDeterministicModel",
3737
"AffineFidelityCostModel",
3838
"ApproximateGPyTorchModel",
39-
"EnsembleMapSaasGP",
40-
"SaasFullyBayesianSingleTaskGP",
41-
"SaasFullyBayesianMultiTaskGP",
39+
"EnsembleMapSaasSingleTaskGP",
4240
"GenericDeterministicModel",
4341
"HigherOrderGP",
4442
"KroneckerMultiTaskGP",
@@ -49,6 +47,8 @@
4947
"PairwiseGP",
5048
"PairwiseLaplaceMarginalLogLikelihood",
5149
"PosteriorMeanModel",
50+
"SaasFullyBayesianMultiTaskGP",
51+
"SaasFullyBayesianSingleTaskGP",
5252
"SingleTaskGP",
5353
"SingleTaskMultiFidelityGP",
5454
"SingleTaskVariationalGP",

botorch/models/map_saas.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ def __init__(self, tau: Tensor | float | None = None):
3939
Args:
4040
tau: Value of the global shrinkage parameter. If `None`, the tau will be
4141
a free parameter and inferred from the data.
42-
Tau can be a tensor for batched models, like `EnsembleMapSaasGP`,
43-
where each batch has a different sparsity prior. If tau is a tensor,
44-
it must have shape `batch_shape`.
42+
Tau can be a tensor for batched models, like
43+
`EnsembleMapSaasSingleTaskGP`, where each batch has a different
44+
sparsity prior. If tau is a tensor, it must have shape `batch_shape`.
4545
"""
4646
self._tau = torch.as_tensor(tau) if tau is not None else None
4747

@@ -427,7 +427,7 @@ def __init__(
427427
self.to(dtype=train_X.dtype, device=train_X.device)
428428

429429

430-
class EnsembleMapSaasGP(SingleTaskGP):
430+
class EnsembleMapSaasSingleTaskGP(SingleTaskGP):
431431
_is_ensemble = True
432432

433433
def __init__(
@@ -440,9 +440,9 @@ def __init__(
440440
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
441441
input_transform: InputTransform | None = None,
442442
) -> None:
443-
"""Instantiates an ``EnsembleMapSaasGP``, which is a batched ensemble of
444-
``SingleTaskGP``s with the Matern-5/2 kernel and a SAAS prior. The model is
445-
intended to be trained with ``ExactMarginalLogLikelihood`` and
443+
"""Instantiates an ``EnsembleMapSaasSingleTaskGP``, which is a batched
444+
ensemble of ``SingleTaskGP``s with the Matern-5/2 kernel and a SAAS prior.
445+
The model is intended to be trained with ``ExactMarginalLogLikelihood`` and
446446
``fit_gpytorch_mll``. Under the hood, the model is equivalent to a
447447
multi-output ``BatchedMultiOutputGPyTorchModel``, but it produces a
448448
``MixtureGaussiaPosterior``, which leads to ensembling of the model outputs.

test/models/test_map_saas.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botorch.models.map_saas import (
2323
add_saas_prior,
2424
AdditiveMapSaasSingleTaskGP,
25-
EnsembleMapSaasGP,
25+
EnsembleMapSaasSingleTaskGP,
2626
get_additive_map_saas_covar_module,
2727
get_gaussian_likelihood_with_gamma_prior,
2828
get_mean_module_with_normal_prior,
@@ -527,7 +527,7 @@ def test_emsemble_map_saas(self) -> None:
527527
}
528528
else:
529529
extra_inputs = {}
530-
model = EnsembleMapSaasGP(
530+
model = EnsembleMapSaasSingleTaskGP(
531531
train_X=train_X, train_Y=train_Y, num_taus=num_taus, **extra_inputs
532532
)
533533
sample_all_priors(model) # Checks that the prior is configured correctly.
@@ -553,16 +553,20 @@ def test_emsemble_map_saas(self) -> None:
553553

554554
def test_ensemble_map_saas_validation(self) -> None:
555555
with self.assertRaisesRegex(ValueError, "Expected taus to be of shape"):
556-
EnsembleMapSaasGP(
556+
EnsembleMapSaasSingleTaskGP(
557557
train_X=torch.rand(5, 3),
558558
train_Y=torch.rand(5, 1),
559559
num_taus=3,
560560
taus=torch.rand(2),
561561
)
562562
with self.assertRaisesRegex(UnsupportedError, "only supports single-output"):
563-
EnsembleMapSaasGP(train_X=torch.rand(5, 3), train_Y=torch.rand(5, 2))
563+
EnsembleMapSaasSingleTaskGP(
564+
train_X=torch.rand(5, 3), train_Y=torch.rand(5, 2)
565+
)
564566
with self.assertRaisesRegex(UnsupportedError, "only supports 2D inputs"):
565-
EnsembleMapSaasGP(train_X=torch.rand(2, 5, 3), train_Y=torch.rand(2, 5, 1))
567+
EnsembleMapSaasSingleTaskGP(
568+
train_X=torch.rand(2, 5, 3), train_Y=torch.rand(2, 5, 1)
569+
)
566570

567571

568572
class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase):

0 commit comments

Comments
 (0)