
Commit 7df149c

sdaulton authored and facebook-github-bot committed

add one hot to numeric input transform (meta-pytorch#1517)

Summary:
Pull Request resolved: meta-pytorch#1517

see title

Differential Revision: https://internalfb.com/D41482322
fbshipit-source-id: 56703ad4dca8b09167b8d8365043b8b8b33148f5

1 parent d026e80

File tree

2 files changed: +213 −0 lines changed

botorch/models/transforms/input.py (133 additions, 0 deletions)

@@ -31,6 +31,7 @@
 from torch import nn, Tensor
 from torch.distributions import Kumaraswamy
 from torch.nn import Module, ModuleDict
+from torch.nn.functional import one_hot


 class InputTransform(ABC):

@@ -1370,3 +1371,135 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
         else:
             p = p(X) if self.indices is None else p(X[..., self.indices])
         return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d
+
+
+class OneHotToNumeric(InputTransform, Module):
+    r"""Transform categorical parameters from a one-hot to a numeric representation.
+
+    This assumes that the categoricals are the trailing dimensions.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        categorical_features: Optional[Dict[int, int]] = None,
+        transform_on_train: bool = False,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = False,
+    ) -> None:
+        r"""Initialize.
+
+        Args:
+            dim: The dimension of the one-hot-encoded input.
+            categorical_features: A dictionary mapping the starting index of each
+                categorical feature to its cardinality. This assumes that the
+                categoricals are one-hot encoded.
+            transform_on_train: A boolean indicating whether to apply the
+                transform in train() mode. Default: False.
+            transform_on_eval: A boolean indicating whether to apply the
+                transform in eval() mode. Default: True.
+            transform_on_fantasize: A boolean indicating whether to apply the
+                transform when called from within a `fantasize` call.
+                Default: False.
+        """
+        super().__init__()
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+        categorical_features = categorical_features or {}
+        # sort by starting index
+        self.categorical_features = OrderedDict(
+            sorted(categorical_features.items(), key=lambda x: x[0])
+        )
+        if len(self.categorical_features) > 0:
+            self.categorical_start_idx = min(self.categorical_features.keys())
+            # check that the trailing dimensions are categoricals
+            end = self.categorical_start_idx
+            err_msg = (
+                f"{self.__class__.__name__} requires that the categorical "
+                "parameters are the rightmost elements."
+            )
+            for start, card in self.categorical_features.items():
+                # the end of one one-hot representation should be followed
+                # by the start of the next
+                if end != start:
+                    raise ValueError(err_msg)
+                # this assumes that the categoricals are the trailing dimensions
+                end = start + card
+            if end != dim:
+                # the last one-hot block must end at the input dimension
+                raise ValueError(err_msg)
+            # the numeric representation dimension is the total number of
+            # parameters (continuous, integer, and categorical)
+            self.numeric_dim = self.categorical_start_idx + len(categorical_features)
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the categorical inputs into integer representation.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d'`-dim tensor where the one-hot encoded
+            categoricals are transformed to integer representation.
+        """
+        if len(self.categorical_features) > 0:
+            X_numeric = X[..., : self.numeric_dim].clone()
+            idx = self.categorical_start_idx
+            for start, card in self.categorical_features.items():
+                X_numeric[..., idx] = X[..., start : start + card].argmax(dim=-1)
+                idx += 1
+            return X_numeric
+        return X
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Transform the categoricals from integer representation to one-hot.
+
+        Args:
+            X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where
+                the categoricals are represented as integers.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of inputs, where the categoricals
+            have been transformed to one-hot representation.
+        """
+        if len(self.categorical_features) > 0:
+            one_hot_categoricals = [
+                # note that self.categorical_features is sorted by the starting
+                # index in one-hot representation
+                one_hot(
+                    X[..., idx - len(self.categorical_features)].long(),
+                    num_classes=cardinality,
+                )
+                for idx, cardinality in enumerate(self.categorical_features.values())
+            ]
+            X = torch.cat(
+                [
+                    X[..., : self.categorical_start_idx],
+                    *one_hot_categoricals,
+                ],
+                dim=-1,
+            )
+        return X
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+        """
+        return (
+            type(self) == type(other)
+            and (self.transform_on_train == other.transform_on_train)
+            and (self.transform_on_eval == other.transform_on_eval)
+            and (self.transform_on_fantasize == other.transform_on_fantasize)
+            and self.categorical_features == other.categorical_features
+        )
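For reference, here is a minimal round-trip sketch of the new transform, assuming the same layout the tests below use (three continuous features followed by a 3-class and a 2-class one-hot categorical); the shapes and values are illustrative only, not prescribed by the commit:

import torch
from torch.nn.functional import one_hot
from botorch.models.transforms.input import OneHotToNumeric

# d = 8: indices 0-2 are continuous, 3-5 one-hot (3 classes), 6-7 one-hot (2 classes)
tf = OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2})
tf.eval()  # transform_on_eval=True by default, so the transform applies in eval mode

cont = torch.rand(5, 3)
cat1 = one_hot(torch.randint(0, 3, (5,)), num_classes=3)
cat2 = one_hot(torch.randint(0, 2, (5,)), num_classes=2)
X = torch.cat([cont, cat1, cat2], dim=-1)  # 5 x 8, one-hot representation

X_numeric = tf(X)  # 5 x 5: three continuous columns plus two integer-coded columns
X_roundtrip = tf.untransform(X_numeric)  # 5 x 8, back to one-hot
assert torch.equal(X_roundtrip, X)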

test/models/transforms/test_input.py (80 additions, 0 deletions)

@@ -21,6 +21,7 @@
     InputTransform,
     Log10,
     Normalize,
+    OneHotToNumeric,
     Round,
     Warp,
 )

@@ -915,6 +916,85 @@ def test_warp_transform(self):
         warp_tf._set_concentration(i=1, value=3.0)
         self.assertTrue((warp_tf.concentration1 == 3.0).all())

+    def test_one_hot_to_numeric(self):
+        dim = 8
+        # test exception when categoricals are not the trailing dimensions
+        categorical_features = {0: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        # categoricals at start and end of X, but not in between
+        categorical_features = {0: 3, 6: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        for dtype in (torch.float, torch.double):
+            categorical_features = {6: 2, 3: 3}
+            tf = OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+            tf.eval()
+            self.assertEqual(tf.categorical_features, {3: 3, 6: 2})
+            cat1_numeric = torch.randint(0, 3, (3,), device=self.device)
+            cat1 = one_hot(cat1_numeric, num_classes=3)
+            cat2_numeric = torch.randint(0, 2, (3,), device=self.device)
+            cat2 = one_hot(cat2_numeric, num_classes=2)
+            cont = torch.rand(3, 3, dtype=dtype, device=self.device)
+            X = torch.cat([cont, cat1, cat2], dim=-1)
+            # test forward
+            X_numeric = tf(X)
+            expected = torch.cat(
+                [
+                    cont,
+                    cat1_numeric.view(-1, 1).to(cont),
+                    cat2_numeric.view(-1, 1).to(cont),
+                ],
+                dim=-1,
+            )
+            self.assertTrue(torch.equal(X_numeric, expected))
+
+            # test untransform
+            X2 = tf.untransform(X_numeric)
+            self.assertTrue(torch.equal(X2, X))
+
+            # test no categorical features (the transform is a no-op)
+            tf = OneHotToNumeric(dim=dim, categorical_features={})
+            tf.eval()
+            X_tf = tf(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            X2 = tf(X_tf)
+            self.assertTrue(torch.equal(X2, X_tf))
+
+            # test no transform on eval
+            tf2 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_eval=False
+            )
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+
+            # test no transform on train
+            tf2 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=False
+            )
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertFalse(torch.equal(X, X_tf))
+
+            # test equals
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=False
+            )
+            self.assertTrue(tf3.equals(tf2))
+            # test different transform_on_train
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features=categorical_features, transform_on_train=True
+            )
+            self.assertFalse(tf3.equals(tf2))
+            # test different categorical_features
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features={}, transform_on_train=False
+            )
+            self.assertFalse(tf3.equals(tf2))


 class TestAppendFeatures(BotorchTestCase):
     def test_append_features(self):
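As the tests exercise, the constructor rejects any layout where the one-hot blocks do not contiguously tile the trailing dimensions. A small hypothetical sketch of that validation behavior:

from botorch.models.transforms.input import OneHotToNumeric

# valid: the blocks starting at 3 and 6 tile [3, 8), so the categoricals are trailing
OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2})

# invalid: the single block covers [3, 6), leaving indices 6-7 after the categoricals
try:
    OneHotToNumeric(dim=8, categorical_features={3: 3})
except ValueError as e:
    # "OneHotToNumeric requires that the categorical parameters are the
    # rightmost elements."
    print(e)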
