|
6 | 6 |
|
7 | 7 | import itertools
|
8 | 8 | import warnings
|
| 9 | +from functools import partial |
9 | 10 |
|
10 | 11 | import torch
|
11 | 12 | from botorch.acquisition.objective import ScalarizedPosteriorTransform
|
|
24 | 25 | from botorch.models.model import FantasizeMixin
|
25 | 26 | from botorch.models.multitask import MultiTaskGP
|
26 | 27 | from botorch.models.transforms import Standardize
|
27 |
| -from botorch.models.transforms.input import ChainedInputTransform, InputTransform |
| 28 | +from botorch.models.transforms.input import ( |
| 29 | + ChainedInputTransform, |
| 30 | + InputTransform, |
| 31 | + NumericToCategoricalEncoding, |
| 32 | +) |
28 | 33 | from botorch.models.utils import fantasize
|
29 | 34 | from botorch.posteriors.gpytorch import GPyTorchPosterior
|
30 | 35 | from botorch.sampling.normal import SobolQMCNormalSampler
|
|
39 | 44 | from gpytorch.settings import trace_mode
|
40 | 45 | from torch import Tensor
|
41 | 46 |
|
| 47 | +from torch.nn.functional import one_hot |
| 48 | + |
42 | 49 |
|
43 | 50 | class SimpleInputTransform(InputTransform, torch.nn.Module):
|
44 | 51 | def __init__(self, transform_on_train: bool) -> None:
|
@@ -691,6 +698,43 @@ def test_condition_on_observations_model_list(self):
|
691 | 698 | X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
|
692 | 699 | )
|
693 | 700 |
|
| 701 | + def test_condition_on_observations_input_transform_shape_manipulation(self): |
| 702 | + for dtype in (torch.float, torch.double): |
| 703 | + tkwargs = {"device": self.device, "dtype": dtype} |
| 704 | + |
| 705 | + # Create data |
| 706 | + X = torch.rand(12, 2, **tkwargs) * 2 |
| 707 | + Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True) |
| 708 | + Y += 0.1 * torch.rand_like(Y) |
| 709 | + # Add a categorical feature |
| 710 | + new_col = torch.randint(0, 3, (X.shape[0], 1), **tkwargs) |
| 711 | + X = torch.cat([X, new_col], dim=1) |
| 712 | + |
| 713 | + train_X = X[:10] |
| 714 | + train_Y = Y[:10] |
| 715 | + |
| 716 | + condition_X = X[10:] |
| 717 | + condition_Y = Y[10:] |
| 718 | + |
| 719 | + # setup transform and model |
| 720 | + input_transform = NumericToCategoricalEncoding( |
| 721 | + dim=3, |
| 722 | + categorical_features={2: 3}, |
| 723 | + encoders={2: partial(one_hot, num_classes=3)}, |
| 724 | + ) |
| 725 | + |
| 726 | + model = SimpleGPyTorchModel( |
| 727 | + train_X, train_Y, input_transform=input_transform |
| 728 | + ) |
| 729 | + model.eval() |
| 730 | + _ = model.posterior(train_X) |
| 731 | + |
| 732 | + conditioned_model = model.condition_on_observations( |
| 733 | + condition_X, condition_Y |
| 734 | + ) |
| 735 | + self.assertAllClose(conditioned_model._original_train_inputs, X) |
| 736 | + self.assertAllClose(conditioned_model.train_inputs[0], input_transform(X)) |
| 737 | + |
694 | 738 | def test_condition_on_observations_input_transform_consistency(self):
|
695 | 739 | """Test that input transforms are applied consistently in
|
696 | 740 | condition_on_observations.
|
|
0 commit comments