Skip to content

Commit

Permalink
Ignore sequential argument when using optimize_acqf_mixed (facebo…
Browse files Browse the repository at this point in the history
…ok#2545)

Summary:
Pull Request resolved: facebook#2545

Passing `sequential` has been raising a deprecation warning, and stopping passing it enables  pytorch/botorch#2390 .

Currently, the `sequential` argument is always provided by Ax and can be overridden by the user. This solution silently ignores the argument whenever the mixed optimizer is used. A nicer solution would be for Ax to only construct the `sequential` argument when it is needed and for there to be an exception when the user passes `sequential=False` and the mixed optimizer is used. If BoTorch plans to eventually enable `sequential=True` with `optimize_acqf_mixed`, then the exception should be raised by BoTorch so that Ax doesn't have to stay in sync with BoTorch's current capabilities. However, I think this code could use a thorough cleanup, so I went with the simple solution rather than add more `if` statements.

Reviewed By: Balandat

Differential Revision: D59057005

fbshipit-source-id: c545f030672b7e30c4405c734b7e7c6605d8a1f8
  • Loading branch information
esantorella authored and facebook-github-bot committed Jun 26, 2024
1 parent 8bc2c59 commit 18d0842
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 2 additions & 0 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ def optimize(
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
optimizer_options_with_defaults.pop("sequential")
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
Expand Down
5 changes: 1 addition & 4 deletions ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,7 @@ def test_optimize_acqf_discrete_local_search(
tkwargs = {"dtype": self.X.dtype, "device": self.X.device}
ssd = SearchSpaceDigest(
feature_names=["a", "b", "c"],
# pyre-fixme[6]: For 2nd param expected `List[Tuple[Union[float, int],
# Union[float, int]]]` but got `List[Tuple[int, int, int]]`.
bounds=[(0, 0, 0), (1, 1, 1)],
bounds=[(0, 1) for _ in range(3)],
categorical_features=[0, 1, 2],
discrete_choices={ # 30 * 60 * 90 > 100,000
k: np.linspace(0, 1, 30 * (k + 1)).tolist() for k in range(3)
Expand Down Expand Up @@ -568,7 +566,6 @@ def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None:
)
mock_optimize_acqf_mixed.assert_called_with(
acq_function=acquisition.acqf,
sequential=True,
bounds=mock.ANY,
q=3,
options={"init_batch_limit": 32, "batch_limit": 5},
Expand Down

0 comments on commit 18d0842

Please sign in to comment.