
Commit 26e96d4

Jihao Andreas Lin authored and meta-codesync[bot] committed
Implement construct_inputs class method for Ax integration (meta-pytorch#3037)
Summary:
Pull Request resolved: meta-pytorch#3037

LatentKroneckerGP requires `train_X`, `train_T`, and `train_Y` as input data, where `train_X` and `train_T` define the Cartesian product space and `train_Y` are the corresponding observations (with potentially missing values). Ax provides the data as samples from the product space, and we need to separate it into the individual factors.

For example, let X = [a, b, c] and T = [0, 1]; then the full product space is {(a, 0), (a, 1), (b, 0), (b, 1), (c, 0), (c, 1)}. Ax would provide us with observations like

x1 = (a, 0), y1 = 1
x2 = (a, 1), y2 = 2
x3 = (b, 0), y3 = 3
x4 = (c, 1), y4 = 4

and we need to transform them into X = [a, b, c], T = [0, 1], and Y = [[1, 2], [3, nan], [nan, 4]] (note that the y values for (b, 1) and (c, 0) are missing).

Reviewed By: saitcakmak

Differential Revision: D83781022

fbshipit-source-id: 6a0d153fd8f776a4a33acf1f0581d76a0ba31148
1 parent 2cc41dc commit 26e96d4
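
To make the transformation concrete, here is a minimal standalone sketch of the factorization described in the summary (plain PyTorch, with a, b, c encoded as 0.0, 1.0, 2.0; an illustration of the idea, not the code from this commit):

```python
import torch

# Product-space samples from the summary: (a, 0), (a, 1), (b, 0), (c, 1),
# with a, b, c encoded as 0.0, 1.0, 2.0. (b, 1) and (c, 0) are unobserved.
xt = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [2.0, 1.0]])
y = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Factorize the samples into the unique factors X and T.
X, x_idx = xt[:, :1].unique(sorted=True, return_inverse=True, dim=0)
T, t_idx = xt[:, 1:].unique(sorted=True, return_inverse=True, dim=0)

# Scatter the observations into the full Cartesian product, NaN elsewhere.
Y = torch.full((X.shape[0] * T.shape[0],), torch.nan)
Y[x_idx * T.shape[0] + t_idx] = y
Y = Y.reshape(X.shape[0], T.shape[0])
print(Y)  # tensor([[1., 2.], [3., nan], [nan, 4.]])
```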

2 files changed: +135 −11 lines changed

botorch/models/latent_kronecker_gp.py

Lines changed: 64 additions & 0 deletions
```diff
@@ -24,17 +24,20 @@
 """
 
 import contextlib
+import warnings
 from typing import Any
 
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions.errors import BotorchTensorDimensionError
+from botorch.exceptions.warnings import InputDataWarning
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import FantasizeMixin, Model
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.latent_kronecker import LatentKroneckerGPPosterior
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.types import _DefaultType, DEFAULT
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import MaternKernel, ScaleKernel
```
```diff
@@ -427,3 +430,64 @@ def condition_on_observations(
         raise NotImplementedError(
             f"Conditioning currently not supported for {self.__class__.__name__}"
         )
+
+    @classmethod
+    def construct_inputs(cls, training_data: SupervisedDataset) -> dict[str, Any]:
+        """
+        Constructs the input tensors for LatentKroneckerGP from a SupervisedDataset.
+
+        This method processes the provided training data to extract and organize the
+        features and targets into the required format for the LatentKroneckerGP model.
+        It factorizes inputs from the product space into the factors X and T.
+        The matching output Y values are assembled by mapping observed values to their
+        corresponding positions and filling missing values with NaN.
+
+        Args:
+            training_data: A SupervisedDataset containing training inputs and outputs.
+
+        Returns:
+            A dictionary with keys `train_X`, `train_T`, and `train_Y`, where:
+            - `train_X`: The unique feature values (excluding the T dimension).
+            - `train_T`: The unique feature values of the T dimension.
+            - `train_Y`: The outputs aligned with the Cartesian product of
+                `train_X` and `train_T`, with missing values filled as NaN.
+        """
+        model_inputs = super().construct_inputs(training_data=training_data)
+
+        if "train_Yvar" in model_inputs:
+            warnings.warn(
+                "Ignoring Yvar values in provided training data, because "
+                "they are currently not supported by LatentKroneckerGP.",
+                InputDataWarning,
+                stacklevel=2,
+            )
+
+        t_idx = training_data.feature_names.index("step")
+        x_idx = [i for i in range(len(training_data.feature_names)) if i != t_idx]
+
+        # Factorize product space into factors X and T by finding unique values
+        train_X, x_idx = model_inputs["train_X"][..., x_idx].unique(
+            sorted=True, return_inverse=True, dim=-2
+        )
+        train_T, t_idx = model_inputs["train_X"][..., [t_idx]].unique(
+            sorted=True, return_inverse=True, dim=-2
+        )
+
+        # Initialize train_Y with NaN for the full Cartesian product
+        batch_shape = train_X.shape[:-2]
+        n_x = train_X.shape[-2]
+        n_t = train_T.shape[-2]
+        train_Y = torch.full(
+            (*batch_shape, n_x * n_t, 1),
+            torch.nan,
+            dtype=model_inputs["train_Y"].dtype,
+            device=model_inputs["train_Y"].device,
+        )
+
+        # Convert 2D indices to 1D indices
+        y_idx = x_idx * n_t + t_idx
+        # Map original observations to their positions in the Cartesian product
+        train_Y[..., y_idx, :] = model_inputs["train_Y"]
+        train_Y = train_Y.reshape(*batch_shape, n_x, n_t)
+
+        return {"train_X": train_X, "train_T": train_T, "train_Y": train_Y}
```
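
For context, here is a hypothetical end-to-end usage sketch of the new class method. The dataset values mirror the toy example above; the feature named "step" is what `construct_inputs` looks up as the T dimension. Passing the returned dictionary to the constructor follows BoTorch's usual `construct_inputs` convention and is an assumption, not part of this diff:

```python
import torch
from botorch.models.latent_kronecker_gp import LatentKroneckerGP
from botorch.utils.datasets import SupervisedDataset

# Four observed points of the 3 x 2 product space; "step" is the T dimension.
X = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [2.0, 1.0]])
Y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
dataset = SupervisedDataset(
    X=X, Y=Y, feature_names=["x_0", "step"], outcome_names=["y"]
)

model_inputs = LatentKroneckerGP.construct_inputs(dataset)
# model_inputs["train_Y"] has shape (3, 2), with NaNs at (b, 1) and (c, 0).
model = LatentKroneckerGP(**model_inputs)  # assumed constructor usage
```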

test/models/test_latent_kronecker_gp.py

Lines changed: 71 additions & 11 deletions
```diff
@@ -10,10 +10,11 @@
 import torch
 from botorch.acquisition.objective import ScalarizedPosteriorTransform
 from botorch.exceptions.errors import BotorchTensorDimensionError
-from botorch.exceptions.warnings import OptimizationWarning
+from botorch.exceptions.warnings import InputDataWarning, OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models.latent_kronecker_gp import LatentKroneckerGP
 from botorch.models.transforms import Normalize, Standardize
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase, get_random_data
 from botorch.utils.types import DEFAULT
 from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
@@ -38,7 +39,7 @@ def _get_data_with_missing_entries(
     mask[torch.randperm(n_train * t)[: n_train * t // 2]] = False
     train_Y[..., ~mask.reshape(n_train, t)] = torch.nan
 
-    return train_X, train_T, train_Y
+    return train_X, train_T, train_Y, mask
 
 
 class TestLatentKroneckerGP(BotorchTestCase):
@@ -71,7 +72,7 @@ def test_default_init(self):
             intf = None
             octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, mask = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
```

```diff
@@ -85,8 +86,7 @@ def test_default_init(self):
             model.to(**tkwargs)
 
             # test init
-            mask_valid = torch.isfinite(train_Y.reshape(-1, n_train, t)[0]).flatten()
-            train_Y_flat = train_Y.reshape(*batch_shape, -1)[..., mask_valid]
+            train_Y_flat = train_Y.reshape(*batch_shape, -1)[..., mask]
             if use_transforms:
                 self.assertIsInstance(model.input_transform, Normalize)
                 self.assertIsInstance(model.outcome_transform, Standardize)
```
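
The removed `mask_valid` computation and the returned `mask` select the same entries; here is a small self-contained check of that equivalence (hypothetical standalone code mimicking the test helper's masking, not taken from the diff):

```python
import torch

n_train, t = 4, 3
train_Y = torch.randn(n_train, t)
mask = torch.ones(n_train * t, dtype=torch.bool)
mask[torch.randperm(n_train * t)[: n_train * t // 2]] = False
train_Y[~mask.reshape(n_train, t)] = torch.nan  # knock out unobserved entries

# Selecting with the returned mask yields exactly the finite entries,
# so the tests no longer need to recompute them via torch.isfinite.
flat = train_Y.reshape(-1)
assert torch.equal(flat[mask], flat[torch.isfinite(flat)])
```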
```diff
@@ -124,7 +124,7 @@ def test_custom_init(self):
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
```

```diff
@@ -230,7 +230,7 @@ def test_gp_train(self):
             intf = None
             octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
```

```diff
@@ -271,7 +271,7 @@ def _test_gp_eval_shapes(
             intf = None
             octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
```

```diff
@@ -441,7 +441,7 @@ def test_gp_eval_values(self):
             intf = None
             octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
```

```diff
@@ -507,7 +507,7 @@ def test_iterative_methods(self):
         batch_shape = torch.Size([])
         tkwargs = {"device": self.device, "dtype": torch.double}
 
-        train_X, train_T, train_Y = _get_data_with_missing_entries(
+        train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
             n_train=10, d=1, t=1, batch_shape=batch_shape, tkwargs=tkwargs
         )
 
```

```diff
@@ -525,7 +525,7 @@ def test_not_implemented(self):
         batch_shape = torch.Size([])
         tkwargs = {"device": self.device, "dtype": torch.double}
 
-        train_X, train_T, train_Y = _get_data_with_missing_entries(
+        train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
            n_train=10, d=1, t=1, batch_shape=batch_shape, tkwargs=tkwargs
         )
 
```

```diff
@@ -558,3 +558,63 @@
         err_msg = f"Only GaussianLikelihood currently supported for {cls_name}"
         with self.assertRaisesRegex(NotImplementedError, err_msg):
             model.posterior(train_X)
+
+    def test_construct_inputs(self) -> None:
+        # This test relies on the fact that the random (missing) data generation
+        # does not remove all occurrences of a particular X or T value. Therefore,
+        # we fix the random seed and set n_train and t to slightly larger values.
+
+        torch.manual_seed(12345)
+        for batch_shape, n_train, d, t, dtype in itertools.product(
+            (  # batch_shape
+                torch.Size([]),
+                torch.Size([1]),
+                torch.Size([2]),
+                torch.Size([2, 3]),
+            ),
+            (15,),  # n_train
+            (1, 2),  # d
+            (10,),  # t
+            (torch.float, torch.double),  # dtype
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            train_X, train_T, train_Y, mask = _get_data_with_missing_entries(
+                n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
+            )
+
+            train_X_supervised = torch.cat(
+                [
+                    train_X.repeat_interleave(t, dim=-2),
+                    train_T.repeat(*([1] * len(batch_shape)), n_train, 1),
+                ],
+                dim=-1,
+            )
+            train_Y_supervised = train_Y.reshape(*batch_shape, n_train * t, 1)
+
+            # randomly permute data to test robustness to non-contiguous data
+            idx = torch.randperm(n_train * t, device=self.device)
+            train_X_supervised = train_X_supervised[..., idx, :][..., mask[idx], :]
+            train_Y_supervised = train_Y_supervised[..., idx, :][..., mask[idx], :]
+
+            dataset = SupervisedDataset(
+                X=train_X_supervised,
+                Y=train_Y_supervised,
+                Yvar=train_Y_supervised,  # just to check warning
+                feature_names=[f"x_{i}" for i in range(d)] + ["step"],
+                outcome_names=["y"],
+            )
+
+            w_msg = "Ignoring Yvar values in provided training data, because "
+            w_msg += "they are currently not supported by LatentKroneckerGP."
+            with self.assertWarnsRegex(InputDataWarning, w_msg):
+                model_inputs = LatentKroneckerGP.construct_inputs(dataset)
+
+            # this test generates train_X and train_T in sorted order
+            # the data is randomly permuted before passing to construct_inputs
+            # construct_inputs sorts the data, so we expect the results to be equal
+            self.assertAllClose(model_inputs["train_X"], train_X, atol=0.0)
+            self.assertAllClose(model_inputs["train_T"], train_T, atol=0.0)
+            self.assertAllClose(
+                model_inputs["train_Y"], train_Y, atol=0.0, equal_nan=True
+            )
```
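
As an aside, the `repeat_interleave` / `repeat` pairing in the new test is a common way to enumerate a Cartesian product row by row; here is a tiny sketch of the pattern in isolation (illustrative values, not taken from the diff):

```python
import torch

X = torch.tensor([[0.0], [1.0], [2.0]])  # n = 3 "x" points
T = torch.tensor([[0.0], [1.0]])         # t = 2 "step" values

# Repeat each row of X t times, tile T n times, then concatenate columns:
# rows come out as (0,0), (0,1), (1,0), (1,1), (2,0), (2,1).
prod = torch.cat([X.repeat_interleave(2, dim=-2), T.repeat(3, 1)], dim=-1)
print(prod.shape)  # torch.Size([6, 2])
```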
