Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev #6651

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
1adb77b
__sklearn_tags__ replacing sklearn's BaseEstimator._more_tags_
vnherdeiro Sep 11, 2024
8ed87d2
fixing tags dict -> dataclass
vnherdeiro Sep 11, 2024
32ec431
fixing wrong import
vnherdeiro Sep 11, 2024
ade9798
remove type hint
vnherdeiro Sep 11, 2024
2085a12
remove type hint
vnherdeiro Sep 11, 2024
a9ec348
fix linting
vnherdeiro Sep 11, 2024
fcc4e12
triggering new CI (scikit-learn dev has changed)
vnherdeiro Sep 14, 2024
3b15646
bringing back _more_tags, adding convertsion from more_tags to sklear…
vnherdeiro Sep 15, 2024
34d9eb4
lint fix
vnherdeiro Sep 15, 2024
6d20ef8
Update python-package/lightgbm/sklearn.py
vnherdeiro Sep 16, 2024
d715311
adressing PR comments
vnherdeiro Sep 16, 2024
c4ec9a4
move comment
jameslamb Sep 16, 2024
b0a4703
updates
jameslamb Sep 21, 2024
7eb861a
remove uses of super()
jameslamb Sep 24, 2024
b137ba2
fix version constraint in lint job, add one more comment
jameslamb Sep 24, 2024
d1915c0
Update python-package/lightgbm/sklearn.py
jameslamb Sep 24, 2024
6cf2158
Merge branch 'master' into fix_sklearn_more_tags_deprecation
jameslamb Sep 26, 2024
b5663aa
Merge branch 'fix_sklearn_more_tags_deprecation' of github.com:vnherd…
jameslamb Sep 26, 2024
118efd9
use scikit-learn 1.6 nightlies again, move some code to compat.py, re…
jameslamb Oct 2, 2024
4fb82f3
optionally use validate_data(), starting in scikit-learn 1.6
jameslamb Oct 3, 2024
c42c53d
fix validate_data() for older versions, update tests
jameslamb Oct 4, 2024
58d77e7
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 4, 2024
33fb5b6
more changes
jameslamb Oct 4, 2024
6689faa
fix n_features_in setting
jameslamb Oct 5, 2024
9a05670
fix return type
jameslamb Oct 5, 2024
815433f
remove now-unnecessary _LGBMCheckXY()
jameslamb Oct 5, 2024
ffebe41
correct comment
jameslamb Oct 5, 2024
722474d
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 6, 2024
f2cb2fe
Apply suggestions from code review
jameslamb Oct 6, 2024
86b5ab3
move __version__ import to compat.py, test with all ML tasks
jameslamb Oct 6, 2024
125f4ea
just set the setters and deleters
jameslamb Oct 6, 2024
4233d70
set floor of scikit-learn>=0.24.2, fix ordering of n_features_in_ set…
jameslamb Oct 6, 2024
330df3f
fix conflicts
jameslamb Oct 6, 2024
e8e4cdb
Update python-package/lightgbm/sklearn.py
jameslamb Oct 6, 2024
0b0ea24
Merge branch 'master' into fix_sklearn_more_tags_deprecation
jameslamb Oct 6, 2024
f22e494
forgot to commit ... fix comment
jameslamb Oct 7, 2024
b124797
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 7, 2024
beab71c
Merge branch 'fix_sklearn_more_tags_deprecation' of github.com:vnherd…
jameslamb Oct 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/test-python-latest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ python -m pip install \
'numpy>=2.0.0.dev0' \
'matplotlib>=3.10.0.dev0' \
'pandas>=3.0.0.dev0' \
'scikit-learn==1.5.*' \
'scikit-learn>=1.6.dev0' \
'scipy>=1.15.0.dev0'

python -m pip install \
Expand Down
2 changes: 1 addition & 1 deletion .ci/test-python-oldest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pip install \
'numpy==1.19.0' \
'pandas==1.1.3' \
'pyarrow==6.0.1' \
'scikit-learn==0.24.0' \
'scikit-learn==0.24.2' \
'scipy==1.6.0' \
|| exit 1
echo "done installing lightgbm's dependencies"
Expand Down
1 change: 1 addition & 0 deletions .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
'mypy>=1.11.1' \
'pre-commit>=3.8.0' \
'pyarrow-core>=17.0' \
'scikit-learn>=1.5.2' \
'r-lintr>=3.1.2'
source activate $CONDA_ENV
echo "Linting Python code"
Expand Down
83 changes: 78 additions & 5 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# coding: utf-8
"""Compatibility library."""

from typing import Any, List
from typing import TYPE_CHECKING, Any, List

# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
"""sklearn"""
try:
from sklearn import __version__ as _sklearn_version
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
Expand All @@ -29,6 +30,69 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X)
return sample_weight

try:
from sklearn.utils.validation import validate_data
except ImportError:
# validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
# It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
def validate_data(
_estimator,
X,
y="no_validation",
accept_sparse: bool = True,
# 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
ensure_all_finite: bool = False,
ensure_min_samples: int = 1,
# trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
**ignored_kwargs,
):
# it's safe to import _num_features unconditionally because:
#
# * it was first added in scikit-learn 0.24.2
# * lightgbm cannot be used with scikit-learn versions older than that
# * this validate_data() re-implementation will not be called in scikit-learn>=1.6
#
from sklearn.utils.validation import _num_features

# _num_features() raises a TypeError on 1-dimensional input. That's a problem
# because scikit-learn's 'check_fit1d' estimator check sets that expectation that
# estimators must raise a ValueError when a 1-dimensional input is passed to fit().
#
# So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if hasattr(X, "shape") and len(X.shape) == 1:
n_features_in_ = 1
else:
n_features_in_ = _num_features(X)

no_val_y = isinstance(y, str) and y == "no_validation"

# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if no_val_y:
X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
else:
X, y = check_X_y(
X,
y,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)

# this only needs to be updated at fit() time
_estimator.n_features_in_ = n_features_in_

# raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
raise ValueError(
f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
f"is expecting {_estimator._n_features} features as input."
)

if no_val_y:
return X
else:
return X, y

SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
_LGBMModelBase = BaseEstimator
Expand All @@ -38,12 +102,11 @@ def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
LGBMNotFittedError = NotFittedError
_LGBMStratifiedKFold = StratifiedKFold
_LGBMGroupKFold = GroupKFold
_LGBMCheckXY = check_X_y
_LGBMCheckArray = check_array
_LGBMCheckSampleWeight = _check_sample_weight
_LGBMAssertAllFinite = assert_all_finite
_LGBMCheckClassificationTargets = check_classification_targets
_LGBMComputeSampleWeight = compute_sample_weight
_LGBMValidateData = validate_data
except ImportError:
SKLEARN_INSTALLED = False

Expand All @@ -67,12 +130,22 @@ class _LGBMRegressorBase: # type: ignore
LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
_LGBMGroupKFold = None
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckSampleWeight = None
_LGBMAssertAllFinite = None
_LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None
_LGBMValidateData = None
_sklearn_version = None

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
# sklearn.utils.Tags can be imported unconditionally once
# lightgbm's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_Tags = None


"""pandas"""
try:
Expand Down
Loading
Loading