Commit 443bcc5

saitcakmak authored and meta-codesync[bot] committed
Improve initialization with continuous relaxation in optimize_acqf_mixed_alternating (meta-pytorch#3041)
Summary:

Pull Request resolved: meta-pytorch#3041

`post_processing_func` is often used to apply Ax transforms to ensure that the generated points are aligned with the Ax search space. The transforms expect the values they are called with to be members of the transformed search space. Since the initialization step already includes rounding operations, and the post-processing function is applied before returning from `optimize_acqf_mixed_alternating`, we can skip it at initialization. Not using `post_processing_func` during initialization avoids the errors that it may raise (due to the points not being in the specified search space) and eliminates the excessive fallback to random initialization that we see in Ax.

Reviewed By: dme65

Differential Revision: D83979491

fbshipit-source-id: 783d634ff24c159d1a46777e534c372da60974b4
1 parent 3ec2ab0 · commit 443bcc5
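
To make the failure mode concrete, here is a minimal sketch (not from the commit; `INTEGER_DIMS` and `strict_post_proc_func` are illustrative names) of a post-processing function that, like an Ax transform, assumes its inputs already lie in the discrete search space:

# Minimal sketch: a post-processing function that, like an Ax transform,
# expects its inputs to already be members of the discrete search space.
# INTEGER_DIMS and strict_post_proc_func are illustrative names.
import torch
from torch import Tensor

INTEGER_DIMS = [0, 3, 4]  # hypothetical integer dimensions


def strict_post_proc_func(X: Tensor) -> Tensor:
    integral = X[..., INTEGER_DIMS]
    if (integral.round() != integral).any():
        # An Ax transform would similarly reject unrounded inputs.
        raise ValueError("Expected discrete values")
    return X

# Before this change, the initialization step could call such a function on
# continuously relaxed (unrounded) points, triggering the error above and
# forcing a fallback to random initialization. Skipping it at initialization
# is safe because initialization already rounds, and the function is still
# applied to the final candidates before they are returned.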

File tree

2 files changed: +57 −1 lines changed

botorch/optim/optimize_mixed.py

Lines changed: 1 addition & 0 deletions
@@ -551,6 +551,7 @@ def generate_starting_points(
             "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
             "init_batch_limit": options.get("init_batch_limit", MAX_BATCH_SIZE),
         },
+        post_processing_func=None,
     )
     x_init_candts, _ = _optimize_acqf(opt_inputs=updated_opt_inputs)
     x_init_candts = x_init_candts.squeeze(-2).detach()
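
The one-line change above hinges on overriding a single field of the input dataclass when building `updated_opt_inputs`. A self-contained sketch of that pattern, assuming the surrounding code uses `dataclasses.replace` (the `OptInputs` class below is a stand-in, not the real `OptimizeAcqfInputs`):

# Sketch of copying a dataclass while overriding one field; OptInputs is a
# stand-in for the real OptimizeAcqfInputs dataclass.
import dataclasses
from collections.abc import Callable


@dataclasses.dataclass(frozen=True)
class OptInputs:
    num_restarts: int
    post_processing_func: Callable | None = None


opt_inputs = OptInputs(num_restarts=4, post_processing_func=lambda X: X.round())

# Copy every field, but drop the post-processing function for the
# initialization call; the original inputs (with the function) are untouched
# and are still used when producing the final candidates.
updated_opt_inputs = dataclasses.replace(opt_inputs, post_processing_func=None)

assert updated_opt_inputs.post_processing_func is None
assert opt_inputs.post_processing_func is not None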

test/optim/test_optimize_mixed.py

Lines changed: 56 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import random
+import warnings
 from dataclasses import fields
 from itertools import product
 from typing import Any, Callable
@@ -58,6 +59,7 @@ def _make_opt_inputs(
     fixed_features: dict[int, float] | None = None,
     return_best_only: bool = True,
     sequential: bool = True,
+    post_processing_func: Callable[[Tensor], Tensor] | None = None,
 ) -> OptimizeAcqfInputs:
     r"""Helper to construct `OptimizeAcqfInputs` from limited inputs."""
     return OptimizeAcqfInputs(
@@ -71,7 +73,7 @@ def _make_opt_inputs(
         equality_constraints=equality_constraints,
         nonlinear_inequality_constraints=nonlinear_inequality_constraints,
         fixed_features=fixed_features or {},
-        post_processing_func=None,
+        post_processing_func=post_processing_func,
         batch_initial_conditions=None,
         return_best_only=return_best_only,
         gen_candidates=gen_candidates_scipy,
@@ -1458,3 +1460,56 @@ def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
         self.assertAllClose(
             discrete_call_args["opt_inputs"].post_processing_func(X), X_expected
         )
+
+    def test_initialization_w_continuous_relaxation(self) -> None:
+        # Testing with integer variables.
+        train_X, _, _, cont_dims = self._get_data()
+        # Update the data to introduce integer dimensions.
+        cat_dims = {0: [0, 1]}
+        discrete_dims = {3: list(range(41)), 4: list(range(16))}
+        all_integer_dims: list[int] = [0, 3, 4]
+
+        def org_post_proc_func(X: Tensor) -> Tensor:
+            # Just error out if things are not rounded already.
+            # This stands in for any Ax transform that expects discrete values.
+            if (X[..., all_integer_dims].round() != X[..., all_integer_dims]).any():
+                raise ValueError("Expected discrete values")
+            return X
+
+        # The key test is that this call doesn't error out.
+        with mock.patch(
+            "botorch.optim.optimize_mixed._optimize_acqf", wraps=_optimize_acqf
+        ) as mock_opt, warnings.catch_warnings(record=True) as ws:
+            X, _ = generate_starting_points(
+                opt_inputs=_make_opt_inputs(
+                    acq_function=qLogNoisyExpectedImprovement(
+                        model=QuadraticDeterministicModel(
+                            root=torch.zeros(
+                                train_X.shape[-1],
+                                device=self.device,
+                                dtype=torch.double,
+                            )
+                        ),
+                        X_baseline=train_X,
+                        cache_root=False,
+                    ),
+                    bounds=torch.tensor(
+                        [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 40.0, 15.0]],
+                        dtype=torch.double,
+                        device=self.device,
+                    ),
+                    post_processing_func=org_post_proc_func,
+                    num_restarts=4,
+                ),
+                discrete_dims=discrete_dims,
+                cat_dims=cat_dims,
+                cont_dims=torch.tensor(cont_dims, device=self.device),
+            )
+        # Check that it was called with no post processing func.
+        mock_opt.assert_called_once()
+        self.assertIsNone(mock_opt.call_args.kwargs["opt_inputs"].post_processing_func)
+        # Check that optimization failure warning is not raised.
+        self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
+        # Check that generated points are rounded.
+        self.assertEqual(X.shape, torch.Size([4, train_X.shape[-1]]))
+        self.assertAllClose(X[..., all_integer_dims], X[..., all_integer_dims].round())
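
Two standard-library mechanisms do the heavy lifting in this test: `mock.patch(..., wraps=...)` lets the real `_optimize_acqf` run while recording how it was called, and `warnings.catch_warnings(record=True)` collects any warnings for inspection after the block. A self-contained sketch of the same pattern, using `math.sqrt` purely as a stand-in target:

# Self-contained sketch of the wrap-and-record pattern; math.sqrt is just a
# stand-in target, unrelated to botorch.
import math
import warnings
from unittest import mock

with mock.patch(
    "math.sqrt", wraps=math.sqrt
) as mock_sqrt, warnings.catch_warnings(record=True) as ws:
    result = math.sqrt(4.0)  # the real sqrt still runs via wraps

mock_sqrt.assert_called_once_with(4.0)  # the call and its args were recorded
assert result == 2.0
assert not ws  # no warnings were emitted inside the block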
