Optimizer parameters validation improvement #542

Open
himkwtn opened this issue Aug 7, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@himkwtn
Collaborator

himkwtn commented Aug 7, 2024

Right now, parameter validation is done in the constructor, as follows:

def __init__(
    self,
    threshold=0.1,
    thresholds=None,
    nu=1.0,
    tol=1e-5,
    thresholder="L0",
    trimming_fraction=0.0,
    trimming_step_size=1.0,
    max_iter=30,
    copy_X=True,
    initial_guess=None,
    normalize_columns=False,
    verbose=False,
    unbias=False,
):
    super(SR3, self).__init__(
        max_iter=max_iter,
        initial_guess=initial_guess,
        copy_X=copy_X,
        normalize_columns=normalize_columns,
        unbias=unbias,
    )
    if threshold < 0:
        raise ValueError("threshold cannot be negative")

According to the scikit-learn developer documentation, parameter validation should be done in the fit method, because calling set_params bypasses any validation done in the constructor.

Reproducing code example:

from pysindy.optimizers import SR3
opt = SR3(threshold=-1)
# raises "ValueError: threshold cannot be negative"

from pysindy.optimizers import SR3
opt = SR3()
opt.set_params(threshold=-1)
# no error
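
For illustration, here is a minimal sketch of what fit-time validation could look like. ToyOptimizer is a hypothetical, dependency-free stand-in and not PySINDy code; its set_params is a simplified imitation of scikit-learn's BaseEstimator.set_params, which assigns attributes without re-running __init__.

```python
class ToyOptimizer:
    """Hypothetical stand-in for an optimizer such as SR3 (not PySINDy code)."""

    def __init__(self, threshold=0.1):
        # scikit-learn convention: __init__ only stores its arguments.
        self.threshold = threshold

    def set_params(self, **params):
        # Simplified imitation of BaseEstimator.set_params:
        # attributes are assigned directly and __init__ is not called again.
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, x, y):
        # Validation lives here, so it always sees the current value,
        # whether it came from __init__ or from set_params.
        if self.threshold < 0:
            raise ValueError("threshold cannot be negative")
        self.fitted_ = True
        return self


opt = ToyOptimizer().set_params(threshold=-1)  # accepted without complaint
try:
    opt.fit([[0.0]], [0.0])
except ValueError as err:
    caught = str(err)  # the bad value is rejected at fit time
```

With this layout, both ToyOptimizer(threshold=-1) and set_params(threshold=-1) succeed, and the error surfaces on the first call to fit.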
himkwtn added the enhancement label Aug 7, 2024
@Jacob-Stevens-Haas
Collaborator

This is a good point, although fairly low impact: I believe set_params() is only used in grid search, which is not done much with SINDy. We could also cheat by doing the same validation in set_params() that we do in __init__(). That said, this is something we want, and the way to do it is:

  • List all classes that break the scikit-learn API in differentiation, feature library, and optimizer
  • Extract all validation into a helper function, with the call moved to the end of __init__()
  • Modify the tests for bad argument combinations to require that __init__() succeeds, and then
  • Move the validation from __init__() to fit()
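
The intermediate state, together with the set_params() "cheat" mentioned above, might look roughly like the sketch below. The names _validate_threshold and ToyOptimizerInterim are hypothetical and chosen only for illustration; the real optimizers have more parameters and inherit set_params() from scikit-learn.

```python
def _validate_threshold(threshold):
    # Hypothetical helper: all argument checks collected in one place.
    if threshold < 0:
        raise ValueError("threshold cannot be negative")


class ToyOptimizerInterim:
    """Hypothetical interim stage (not PySINDy code)."""

    def __init__(self, threshold=0.1):
        self.threshold = threshold
        # Single validation call at the end of __init__.
        _validate_threshold(self.threshold)

    def set_params(self, **params):
        # The "cheat": re-run the same validation before assigning,
        # so a rejected value never reaches the instance.
        _validate_threshold(params.get("threshold", self.threshold))
        for name, value in params.items():
            setattr(self, name, value)
        return self


opt = ToyOptimizerInterim()
try:
    opt.set_params(threshold=-1)
    rejected = False
except ValueError:
    rejected = True  # set_params now refuses the bad value
```

Once the helper call is moved from __init__() into fit(), the set_params() override is no longer needed, because every path to a fitted model passes through the same check.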
