Commit 2a4baa1

esantorella authored and facebook-github-bot committed
Warn when optimizer kwargs are being ignored in BoTorch optim utils _filter_kwargs (meta-pytorch#1645)
Summary:
Pull Request resolved: meta-pytorch#1645

The HOGP tutorial was using kwargs 'maxiter' and 'disp' that were being silently ignored, since the Adam optimizer doesn't take those arguments.

Reviewed By: saitcakmak

Differential Revision: D42729349

fbshipit-source-id: 66361b2f0787196cb200f161357f37d906d82fdd
1 parent 7f3aa92 commit 2a4baa1
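
For context, a minimal sketch of the failure mode this commit addresses. The wiring below is illustrative, not the HOGP tutorial's exact code; it assumes this commit's version of _filter_kwargs is installed. Scipy-style options such as 'maxiter' and 'disp' are not torch.optim.Adam constructor arguments, so _filter_kwargs drops them before the optimizer is built.

    import torch
    from botorch.optim.utils import _filter_kwargs

    params = [torch.zeros(2, requires_grad=True)]
    optimizer_kwargs = {"lr": 0.01, "maxiter": 3000, "disp": False}

    # Before this commit: 'maxiter' and 'disp' were dropped silently.
    # After: a UserWarning reports that they will be ignored.
    adam = torch.optim.Adam(
        params, **_filter_kwargs(torch.optim.Adam, **optimizer_kwargs)
    )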

2 files changed: 30 additions & 3 deletions


botorch/optim/utils/common.py

Lines changed: 10 additions & 2 deletions
@@ -11,7 +11,7 @@
 from inspect import signature
 from logging import debug as logging_debug
 from typing import Any, Callable, Optional, Tuple
-from warnings import warn_explicit, WarningMessage
+from warnings import warn, warn_explicit, WarningMessage
 
 import numpy as np
 from linear_operator.utils.errors import NanError, NotPSDError
@@ -29,7 +29,15 @@ class _TDefault:
 def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
     r"""Filter out kwargs that are not applicable for a given function.
     Return a copy of given kwargs dict with only the required kwargs."""
-    return {k: v for k, v in kwargs.items() if k in signature(function).parameters}
+    allowed_params = signature(function).parameters
+    removed = {k for k in kwargs.keys() if k not in allowed_params}
+    if len(removed) > 0:
+        warn(
+            f"Keyword arguments {list(removed)} will be ignored because they are"
+            f" not allowed parameters for function {function.__name__}. Allowed "
+            f"parameters are {list(allowed_params.keys())}."
+        )
+    return {k: v for k, v in kwargs.items() if k not in removed}
 
 
 def _handle_numerical_errors(
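
The new warning is an ordinary UserWarning raised via warnings.warn, so callers can escalate or silence it with the standard warnings machinery. A hypothetical sketch (step_fn is an illustrative stand-in, not BoTorch code):

    import warnings

    from botorch.optim.utils import _filter_kwargs

    def step_fn(x, lr: float = 0.1) -> None:  # hypothetical stand-in
        pass

    # Escalate: turn the ignored-kwarg warning into an error while debugging.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        try:
            _filter_kwargs(step_fn, lr=0.5, maxiter=10)  # raises UserWarning
        except UserWarning as w:
            print(f"caught: {w}")

    # Silence: ignore it when passing a superset of kwargs is intentional.
    warnings.filterwarnings(
        "ignore", message=r"Keyword arguments .* will be ignored"
    )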

test/optim/utils/test_common.py

Lines changed: 20 additions & 1 deletion
@@ -10,12 +10,31 @@
 from warnings import catch_warnings, warn
 
 import numpy as np
-from botorch.optim.utils import _handle_numerical_errors, _warning_handler_template
+from botorch.optim.utils import (
+    _filter_kwargs,
+    _handle_numerical_errors,
+    _warning_handler_template,
+)
 from botorch.utils.testing import BotorchTestCase
 from linear_operator.utils.errors import NanError, NotPSDError
 
 
 class TestUtilsCommon(BotorchTestCase):
+    def test__filter_kwargs(self) -> None:
+        def mock_adam(params, lr: float = 0.001) -> None:
+            return  # pragma: nocover
+
+        kwargs = {"lr": 0.01, "maxiter": 3000}
+        with catch_warnings(record=True) as ws:
+            valid_kwargs = _filter_kwargs(mock_adam, **kwargs)
+        expected_msg = (
+            "Keyword arguments ['maxiter'] will be ignored because they are not"
+            " allowed parameters for function mock_adam. Allowed parameters are "
+            "['params', 'lr']."
+        )
+        self.assertEqual(expected_msg, str(ws[0].message))
+        self.assertEqual(set(valid_kwargs.keys()), {"lr"})
+
     def test_handle_numerical_errors(self):
         x = np.zeros(1, dtype=np.float64)
