|
7 | 7 | import itertools
|
8 | 8 | import math
|
9 | 9 | import pickle
|
| 10 | +from functools import partial |
10 | 11 | from itertools import product
|
11 | 12 | from typing import Any
|
12 | 13 | from unittest import mock
|
|
27 | 28 | get_gaussian_likelihood_with_gamma_prior,
|
28 | 29 | get_mean_module_with_normal_prior,
|
29 | 30 | )
|
30 |
| -from botorch.models.transforms.input import AppendFeatures, FilterFeatures, Normalize |
| 31 | +from botorch.models.transforms.input import ( |
| 32 | + AppendFeatures, |
| 33 | + FilterFeatures, |
| 34 | + Normalize, |
| 35 | + NumericToCategoricalEncoding, |
| 36 | +) |
31 | 37 | from botorch.models.transforms.outcome import Standardize
|
32 | 38 | from botorch.optim.utils import get_parameters_and_bounds, sample_all_priors
|
33 | 39 | from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
|
|
43 | 49 | from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
|
44 | 50 | from gpytorch.priors import GammaPrior, HalfCauchyPrior, NormalPrior
|
45 | 51 | from torch import Tensor
|
| 52 | +from torch.nn.functional import one_hot |
46 | 53 |
|
47 | 54 |
|
48 | 55 | class TestMapSaas(BotorchTestCase):
|
@@ -647,6 +654,30 @@ def _get_data_and_model(
|
647 | 654 | )
|
648 | 655 | return train_X, train_Y, train_Yvar, model
|
649 | 656 |
|
| 657 | + def test_input_transform_dimensions(self) -> None: |
| 658 | + for dtype in (torch.float, torch.double): |
| 659 | + tkwargs = {"device": self.device, "dtype": dtype} |
| 660 | + # Create data |
| 661 | + X = torch.rand(12, 2, **tkwargs) * 2 |
| 662 | + Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True) |
| 663 | + Y += 0.1 * torch.rand_like(Y) |
| 664 | + # Add a categorical feature |
| 665 | + new_col = torch.randint(0, 3, (X.shape[0], 1), **tkwargs) |
| 666 | + X = torch.cat([X, new_col], dim=1) |
| 667 | + |
| 668 | + input_transform = NumericToCategoricalEncoding( |
| 669 | + dim=3, |
| 670 | + categorical_features={2: 3}, |
| 671 | + encoders={2: partial(one_hot, num_classes=3)}, |
| 672 | + ) |
| 673 | + |
| 674 | + model = AdditiveMapSaasSingleTaskGP( |
| 675 | + train_X=X, |
| 676 | + train_Y=Y, |
| 677 | + input_transform=input_transform, |
| 678 | + ) |
| 679 | + self.assertEqual(model.covar_module.kernels[0].base_kernel.ard_num_dims, 5) |
| 680 | + |
650 | 681 | def test_construct_mean_module(self) -> None:
|
651 | 682 | tkwargs = {"device": self.device, "dtype": torch.double}
|
652 | 683 | for batch_shape in [None, torch.Size([5])]:
|
|
0 commit comments