Skip to content

Commit 63dd0cd

Browse files
esantorellafacebook-github-bot
authored andcommitted
Bug fix: _filter_kwargs was erroring when provided a function without a __name__ attribute (#1678)
Summary: Pull Request resolved: #1678 See #1667 Reviewed By: danielrjiang Differential Revision: D43286116 fbshipit-source-id: 3da3e6ff23b517f5379ee90f407dc04d4f2ad06e
1 parent ad38736 commit 63dd0cd

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

botorch/optim/utils/common.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,14 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
2323
allowed_params = signature(function).parameters
2424
removed = {k for k in kwargs.keys() if k not in allowed_params}
2525
if len(removed) > 0:
26+
fn_descriptor = (
27+
f" for function {function.__name__}"
28+
if hasattr(function, "__name__")
29+
else ""
30+
)
2631
warn(
2732
f"Keyword arguments {list(removed)} will be ignored because they are"
28-
f" not allowed parameters for function {function.__name__}. Allowed "
33+
f" not allowed parameters{fn_descriptor}. Allowed "
2934
f"parameters are {list(allowed_params.keys())}."
3035
)
3136
return {k: v for k, v in kwargs.items() if k not in removed}

test/optim/test_optimize.py

-6
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,10 @@ def test_optimize_acqf_joint(
115115
mock_gen_candidates_scipy,
116116
mock_gen_candidates_torch,
117117
):
118-
# Mocks don't have a __name__ attribute.
119-
# Set the attribute, since it is needed for testing _filter_kwargs
120118
if mock_gen_candidates == mock_gen_candidates_torch:
121119
mock_signature.return_value = signature(gen_candidates_torch)
122120
else:
123121
mock_signature.return_value = signature(gen_candidates_scipy)
124-
mock_gen_candidates.__name__ = "gen_candidates"
125122

126123
mock_gen_batch_initial_conditions.return_value = torch.zeros(
127124
num_restarts, q, 3, device=self.device, dtype=dtype
@@ -835,13 +832,10 @@ def nlc(x):
835832
mock_gen_candidates_torch,
836833
mock_gen_candidates_scipy,
837834
):
838-
# Mocks don't have a __name__ attribute.
839-
# Set the attribute, since it is needed for testing _filter_kwargs
840835
if mock_gen_candidates == mock_gen_candidates_torch:
841836
mock_signature.return_value = signature(gen_candidates_torch)
842837
else:
843838
mock_signature.return_value = signature(gen_candidates_scipy)
844-
mock_gen_candidates.__name__ = "gen_candidates"
845839
for dtype in (torch.float, torch.double):
846840

847841
mock_acq_function = MockAcquisitionFunction()

test/optim/utils/test_common.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,23 @@ def mock_adam(params, lr: float = 0.001) -> None:
2525
return # pragma: nocover
2626

2727
kwargs = {"lr": 0.01, "maxiter": 3000}
28-
with catch_warnings(record=True) as ws:
28+
expected_msg = (
29+
r"Keyword arguments \['maxiter'\] will be ignored because they are "
30+
r"not allowed parameters for function mock_adam. Allowed parameters "
31+
r"are \['params', 'lr'\]."
32+
)
33+
34+
with self.assertWarnsRegex(Warning, expected_msg):
2935
valid_kwargs = _filter_kwargs(mock_adam, **kwargs)
36+
self.assertEqual(set(valid_kwargs.keys()), {"lr"})
37+
38+
mock_partial = partial(mock_adam, lr=2.0)
3039
expected_msg = (
31-
"Keyword arguments ['maxiter'] will be ignored because they are not"
32-
" allowed parameters for function mock_adam. Allowed parameters are "
33-
"['params', 'lr']."
40+
r"Keyword arguments \['maxiter'\] will be ignored because they are "
41+
r"not allowed parameters. Allowed parameters are \['params', 'lr'\]."
3442
)
35-
self.assertEqual(expected_msg, str(ws[0].message))
43+
with self.assertWarnsRegex(Warning, expected_msg):
44+
valid_kwargs = _filter_kwargs(mock_partial, **kwargs)
3645
self.assertEqual(set(valid_kwargs.keys()), {"lr"})
3746

3847
def test_handle_numerical_errors(self):

0 commit comments

Comments
 (0)