Skip to content

Commit a42cd65

Browse files
jduerholtmeta-codesync[bot]
authored andcommitted
Correct handling of input transforms for AdditiveMapSaasSingleTaskGP (meta-pytorch#3042)
Summary: <!-- Thank you for sending the PR! We appreciate you spending the time to make BoTorch better. Help us understand your motivation by explaining why you decided to make this change. You can learn more about contributing to BoTorch here: https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md --> ## Motivation In `AdditiveMapSaasSingleTaskGP`, input transforms that alter the shape of the training data are not handled correctly. This PR fixes it. A few questions to the Map Saas complex in general: - We have now the ensemble version and the non-ensemble version, which one would you recommend in general? - We have getter methods that generate fitted or unfitted versions of the MAP models and we have the option to just initialise the respective class and fit it in the default MAP way. What do you recommend here? Currently, I am using the `AdditiveMapSaasSingleTaskGP` via initialising the class and fitting it like any SingleTaskGP. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: meta-pytorch#3042 Test Plan: Unit tests. Reviewed By: Balandat Differential Revision: D84059558 Pulled By: saitcakmak fbshipit-source-id: 415b2ea1ac57c63abb5fa72554cff2cfdd1eae33
1 parent 443bcc5 commit a42cd65

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

botorch/models/map_saas.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,15 @@ def __init__(
403403
if train_Yvar is None
404404
else None
405405
)
406+
if input_transform is not None:
407+
with torch.no_grad():
408+
transformed_X = input_transform(train_X)
409+
ard_num_dims = transformed_X.shape[-1]
410+
else:
411+
ard_num_dims = train_X.shape[-1]
412+
406413
covar_module = get_additive_map_saas_covar_module(
407-
ard_num_dims=train_X.shape[-1],
414+
ard_num_dims=ard_num_dims,
408415
num_taus=num_taus,
409416
batch_shape=self._aug_batch_shape,
410417
# Need to pass dtype and device at initialization of the covar_module

test/models/test_map_saas.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import itertools
88
import math
99
import pickle
10+
from functools import partial
1011
from itertools import product
1112
from typing import Any
1213
from unittest import mock
@@ -27,7 +28,12 @@
2728
get_gaussian_likelihood_with_gamma_prior,
2829
get_mean_module_with_normal_prior,
2930
)
30-
from botorch.models.transforms.input import AppendFeatures, FilterFeatures, Normalize
31+
from botorch.models.transforms.input import (
32+
AppendFeatures,
33+
FilterFeatures,
34+
Normalize,
35+
NumericToCategoricalEncoding,
36+
)
3137
from botorch.models.transforms.outcome import Standardize
3238
from botorch.optim.utils import get_parameters_and_bounds, sample_all_priors
3339
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
@@ -43,6 +49,7 @@
4349
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
4450
from gpytorch.priors import GammaPrior, HalfCauchyPrior, NormalPrior
4551
from torch import Tensor
52+
from torch.nn.functional import one_hot
4653

4754

4855
class TestMapSaas(BotorchTestCase):
@@ -647,6 +654,30 @@ def _get_data_and_model(
647654
)
648655
return train_X, train_Y, train_Yvar, model
649656

657+
def test_input_transform_dimensions(self) -> None:
658+
for dtype in (torch.float, torch.double):
659+
tkwargs = {"device": self.device, "dtype": dtype}
660+
# Create data
661+
X = torch.rand(12, 2, **tkwargs) * 2
662+
Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True)
663+
Y += 0.1 * torch.rand_like(Y)
664+
# Add a categorical feature
665+
new_col = torch.randint(0, 3, (X.shape[0], 1), **tkwargs)
666+
X = torch.cat([X, new_col], dim=1)
667+
668+
input_transform = NumericToCategoricalEncoding(
669+
dim=3,
670+
categorical_features={2: 3},
671+
encoders={2: partial(one_hot, num_classes=3)},
672+
)
673+
674+
model = AdditiveMapSaasSingleTaskGP(
675+
train_X=X,
676+
train_Y=Y,
677+
input_transform=input_transform,
678+
)
679+
self.assertEqual(model.covar_module.kernels[0].base_kernel.ard_num_dims, 5)
680+
650681
def test_construct_mean_module(self) -> None:
651682
tkwargs = {"device": self.device, "dtype": torch.double}
652683
for batch_shape in [None, torch.Size([5])]:

0 commit comments

Comments
 (0)