Skip to content

Commit

Permalink
Silence input data warnings in tests (#2739)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2739

X-link: pytorch/botorch#2508

This logic broke since D61797434 updated the warning messages, leading to too many of these warnings in test outputs again.

Reviewed By: Balandat, esantorella

Differential Revision: D62200731

fbshipit-source-id: a8c802abc613e0b144c6eb817f4692857a4cb83d
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 5, 2024
1 parent 0611c5e commit 66aed70
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def cross_validate(
# users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) not standardized",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
cv_predictions = self._cross_validate(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def cross_validate(
# To avoid confusing users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) not standardized",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
cv_test_predictions = model._cross_validate(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def warn_and_return_mock_obs(
nonlocal called
called = True
warnings.warn(
"Data (outcome observations) not standardized",
"Data (outcome observations) is not standardized",
InputDataWarning,
stacklevel=2,
)
Expand Down
16 changes: 15 additions & 1 deletion ax/utils/common/tests/test_testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

import io
import sys
import warnings

import torch
from ax.utils.common.base import Base
from ax.utils.common.testutils import TestCase
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.gp_regression import SingleTaskGP


# pyre-fixme[3]: Return type must be annotated.
Expand All @@ -19,7 +23,7 @@ def _f():
raise e


F_FAILURE_LINENO = 19 # Line # for the error in `_f`.
F_FAILURE_LINENO = 23 # Line # for the error in `_f`.


def _g() -> None:
Expand Down Expand Up @@ -113,3 +117,13 @@ def decorated_test() -> None:
self.assertEqual(None, self._long_test_active_reason)
decorated_test()
self.assertEqual(None, self._long_test_active_reason)

def test_warning_filtering(self) -> None:
with warnings.catch_warnings(record=True) as ws:
# Model with unstandardized float data, which would typically raise
# multiple warnings.
SingleTaskGP(
train_X=torch.rand(5, 2, dtype=torch.float) * 10,
train_Y=torch.rand(5, 1, dtype=torch.float) * 10,
)
self.assertFalse(any(w.category == InputDataWarning for w in ws))
7 changes: 6 additions & 1 deletion ax/utils/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,12 @@ def setUp(self) -> None:
# BoTorch input standardization warnings.
warnings.filterwarnings(
"ignore",
message="Input data is not",
message=r"Data \(outcome observations\) is not standardized ",
category=InputDataWarning,
)
warnings.filterwarnings(
"ignore",
message=r"Data \(input features\) is not",
category=InputDataWarning,
)

Expand Down

0 comments on commit 66aed70

Please sign in to comment.