|
34 | 34 | from botorch.posteriors.gpytorch import GPyTorchPosterior
|
35 | 35 | from botorch.test_utils.mock import mock_optimize
|
36 | 36 | from botorch.utils.constraints import LogTransformedInterval
|
| 37 | +from botorch.utils.datasets import SupervisedDataset |
37 | 38 | from botorch.utils.testing import BotorchTestCase
|
38 | 39 | from gpytorch.constraints import Interval
|
39 | 40 | from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel
|
@@ -568,6 +569,56 @@ def test_ensemble_map_saas_validation(self) -> None:
|
568 | 569 | train_X=torch.rand(2, 5, 3), train_Y=torch.rand(2, 5, 1)
|
569 | 570 | )
|
570 | 571 |
|
| 572 | + def test_ensemble_map_saas_construct_inputs(self) -> None: |
| 573 | + """Test the construct_inputs class method for EnsembleMapSaasSingleTaskGP.""" |
| 574 | + |
| 575 | + train_X, train_Y, _ = self._get_data() |
| 576 | + training_data = SupervisedDataset( |
| 577 | + X=train_X, Y=train_Y, feature_names=["x1", "x2", "x3"], outcome_names=["y"] |
| 578 | + ) |
| 579 | + |
| 580 | + # Test with default num_taus |
| 581 | + inputs_default = EnsembleMapSaasSingleTaskGP.construct_inputs( |
| 582 | + training_data=training_data |
| 583 | + ) |
| 584 | + self.assertIn("num_taus", inputs_default) |
| 585 | + self.assertEqual(inputs_default["num_taus"], 4) |
| 586 | + self.assertIn("train_X", inputs_default) |
| 587 | + self.assertIn("train_Y", inputs_default) |
| 588 | + self.assertAllClose(inputs_default["train_X"], train_X) |
| 589 | + self.assertAllClose(inputs_default["train_Y"], train_Y) |
| 590 | + |
| 591 | + # Test with custom num_taus |
| 592 | + custom_num_taus = 6 |
| 593 | + inputs_custom = EnsembleMapSaasSingleTaskGP.construct_inputs( |
| 594 | + training_data=training_data, num_taus=custom_num_taus |
| 595 | + ) |
| 596 | + self.assertIn("num_taus", inputs_custom) |
| 597 | + self.assertEqual(inputs_custom["num_taus"], custom_num_taus) |
| 598 | + self.assertIn("train_X", inputs_custom) |
| 599 | + self.assertIn("train_Y", inputs_custom) |
| 600 | + self.assertAllClose(inputs_custom["train_X"], train_X) |
| 601 | + self.assertAllClose(inputs_custom["train_Y"], train_Y) |
| 602 | + |
| 603 | + # Test with train_Yvar in the dataset |
| 604 | + train_Yvar = 0.1 * torch.rand_like(train_Y) |
| 605 | + training_data_with_yvar = SupervisedDataset( |
| 606 | + X=train_X, |
| 607 | + Y=train_Y, |
| 608 | + Yvar=train_Yvar, |
| 609 | + feature_names=["x1", "x2", "x3"], |
| 610 | + outcome_names=["y"], |
| 611 | + ) |
| 612 | + inputs_with_yvar = EnsembleMapSaasSingleTaskGP.construct_inputs( |
| 613 | + training_data=training_data_with_yvar, num_taus=3 |
| 614 | + ) |
| 615 | + self.assertIn("train_Yvar", inputs_with_yvar) |
| 616 | + self.assertAllClose(inputs_with_yvar["train_Yvar"], train_Yvar) |
| 617 | + self.assertEqual(inputs_with_yvar["num_taus"], 3) |
| 618 | + model_with_yvar = EnsembleMapSaasSingleTaskGP(**inputs_with_yvar) |
| 619 | + self.assertIsInstance(model_with_yvar, EnsembleMapSaasSingleTaskGP) |
| 620 | + self.assertEqual(model_with_yvar.batch_shape, torch.Size([3])) |
| 621 | + |
571 | 622 |
|
572 | 623 | class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase):
|
573 | 624 | def _get_data_and_model(
|
|
0 commit comments