From b43608bf2ffc1b17b19f1aee72613c0fe6e5f227 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Fri, 6 Jun 2025 23:02:42 +0200 Subject: [PATCH 1/2] global options --- docs/quickstart.qmd | 38 ++++ pyfixest/__init__.py | 3 + pyfixest/estimation/estimation.py | 276 +++++++++++++++--------------- pyfixest/options.py | 74 ++++++++ 4 files changed, 250 insertions(+), 141 deletions(-) create mode 100644 pyfixest/options.py diff --git a/docs/quickstart.qmd b/docs/quickstart.qmd index 063ec2191..484b8d65f 100644 --- a/docs/quickstart.qmd +++ b/docs/quickstart.qmd @@ -577,3 +577,41 @@ pf.etable([fit_twfe, fit_did2s]) ``` For more details see the vignette on [Difference-in-Differences Estimation](https://py-econometrics.github.io/pyfixest/difference-in-differences.html). + +## Setting Global Options + +We can set all function arguments of the estimation functions `feols`, `fepois`, and `feglm` as global options. + +For example, we can set the estimation data set, variance covariance matrix, weights, small sample corrections and demeaner backend defaults via `pf.set_option()`. + +```{python} +pf.set_option( + data = pf.get_data().dropna(), + vcov = {"CRV1": "f1"}, + weights = "weights", + demeaner_backend = "rust" +) +``` + +In this case, we don't have to provide data, vcov, weights, etc to the model call: they will automatically be applied. 
+ +```{python} +fit1 = pf.feols(fml = "Y~X1") +``` + +If we actively set a function argument, the global default will be overwritten: + +```{python} +fit2 = pf.feols(fml = "Y ~ X1", vcov = "hetero") + +pf.etable([fit1, fit2]) +``` + +We can set a local option context in the following way: + +```{python} +with pf.option_context(vcov="hetero"): + fit3 = pf.feols("Y ~ X1") + +pf.etable([fit1, fit2, fit3]) +``` diff --git a/pyfixest/__init__.py b/pyfixest/__init__.py index 50e4d72d6..722167b8b 100644 --- a/pyfixest/__init__.py +++ b/pyfixest/__init__.py @@ -23,6 +23,7 @@ rwolf, wyoung, ) +from pyfixest.options import option_context, set_option from pyfixest.report import coefplot, dtable, etable, iplot, make_table, summary from pyfixest.utils import ( get_data, @@ -49,9 +50,11 @@ "iplot", "lpdid", "make_table", + "option_context", "panelview", "report", "rwolf", + "set_option", "ssc", "summary", "utils", diff --git a/pyfixest/estimation/estimation.py b/pyfixest/estimation/estimation.py index cd76349c8..990e69066 100644 --- a/pyfixest/estimation/estimation.py +++ b/pyfixest/estimation/estimation.py @@ -1,9 +1,12 @@ +import functools from collections.abc import Mapping +from inspect import signature from typing import Any, Optional, Union +from dataclasses import dataclass, field import pandas as pd +import narwhals as nw -from pyfixest.errors import FeatureDeprecationError from pyfixest.estimation.feols_ import Feols from pyfixest.estimation.fepois_ import Fepois from pyfixest.estimation.FixestMulti_ import FixestMulti @@ -14,30 +17,44 @@ VcovTypeOptions, WeightsTypeOptions, ) +from pyfixest.options import options from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas from pyfixest.utils.utils import capture_context -from pyfixest.utils.utils import ssc as ssc_func +def autofill_with_options(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + sig = signature(func) + bound = sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + for 
name, value in bound.arguments.items(): + if value is None and hasattr(options, name): + bound.arguments[name] = getattr(options, name) + return func(**bound.arguments) + + return wrapper + + +@autofill_with_options def feols( - fml: str, - data: DataFrameType, # type: ignore + fml: Optional[str] = None, + data: Optional[DataFrameType] = None, # type: ignore vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, weights: Union[None, str] = None, ssc: Optional[dict[str, Union[str, bool]]] = None, - fixef_rm: FixedRmOptions = "none", - fixef_tol=1e-08, - collin_tol: float = 1e-10, - drop_intercept: bool = False, - i_ref1=None, - copy_data: bool = True, - store_data: bool = True, - lean: bool = False, - weights_type: WeightsTypeOptions = "aweights", - solver: SolverOptions = "scipy.linalg.solve", - demeaner_backend: DemeanerBackendOptions = "numba", - use_compression: bool = False, - reps: int = 100, + fixef_rm: Optional[FixedRmOptions] = None, + fixef_tol: Optional[float] = None, + collin_tol: Optional[float] = None, + drop_intercept: Optional[bool] = None, + copy_data: Optional[bool] = None, + store_data: Optional[bool] = None, + lean: Optional[bool] = None, + weights_type: Optional[WeightsTypeOptions] = None, + solver: Optional[SolverOptions] = None, + demeaner_backend: Optional[DemeanerBackendOptions] = None, + use_compression: Optional[bool] = None, + reps: Optional[int] = None, context: Optional[Union[int, Mapping[str, Any]]] = None, seed: Optional[int] = None, split: Optional[str] = None, @@ -83,11 +100,6 @@ def feols( drop_intercept : bool, optional Whether to drop the intercept from the model, by default False. - i_ref1: None - Deprecated with pyfixest version 0.18.0. Please use i-syntax instead, i.e. - feols('Y~ i(f1, ref=1)', data = data) instead of the former - feols('Y~ i(f1)', data = data, i_ref=1). - copy_data : bool, optional Whether to copy the data before estimation, by default True. 
If set to False, the data is not copied, which can save memory but @@ -433,37 +445,12 @@ def _lspline(series: pd.Series, knots: list[float]) -> np.array: fit_D.ccv(treatment = "D", cluster = "group_id") ``` """ - if ssc is None: - ssc = ssc_func() - if i_ref1 is not None: - raise FeatureDeprecationError( - """ - The 'i_ref1' function argument is deprecated with pyfixest version 0.18.0. - Please use i-syntax instead, i.e. feols('Y~ i(f1, ref=1)', data = data) - instead of the former feols('Y~ i(f1)', data = data, i_ref=1). - """ - ) + context = {} if context is None else capture_context(context) - _estimation_input_checks( - fml=fml, - data=data, - vcov=vcov, - weights=weights, - ssc=ssc, - fixef_rm=fixef_rm, - collin_tol=collin_tol, - copy_data=copy_data, - store_data=store_data, - lean=lean, - fixef_tol=fixef_tol, - weights_type=weights_type, - use_compression=use_compression, - reps=reps, - seed=seed, - split=split, - fsplit=fsplit, - ) + args = locals() + filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + EstimationInputs(**filtered_args).validate() fixest = FixestMulti( data=data, @@ -500,6 +487,7 @@ def _lspline(series: pd.Series, knots: list[float]) -> np.array: return fixest.fetch_model(0, print_fml=False) +@autofill_with_options def fepois( fml: str, data: DataFrameType, # type: ignore @@ -514,7 +502,6 @@ def fepois( solver: SolverOptions = "scipy.linalg.solve", demeaner_backend: DemeanerBackendOptions = "numba", drop_intercept: bool = False, - i_ref1=None, copy_data: bool = True, store_data: bool = True, lean: bool = False, @@ -581,11 +568,6 @@ def fepois( drop_intercept : bool, optional Whether to drop the intercept from the model, by default False. - i_ref1: None - Deprecated with pyfixest version 0.18.0. Please use i-syntax instead, i.e. - fepois('Y~ i(f1, ref=1)', data = data) instead of the former - fepois('Y~ i(f1)', data = data, i_ref=1). 
- copy_data : bool, optional Whether to copy the data before estimation, by default True. If set to False, the data is not copied, which can save memory but @@ -647,44 +629,15 @@ def fepois( For more examples on the use of other function arguments, please take a look at the documentation of the [feols()](https://py-econometrics.github.io/pyfixest/reference/estimation.estimation.feols.html#pyfixest.estimation.estimation.feols) function. """ - if separation_check is None: - separation_check = ["fe"] - if ssc is None: - ssc = ssc_func() - if i_ref1 is not None: - raise FeatureDeprecationError( - """ - The 'i_ref1' function argument is deprecated with pyfixest version 0.18.0. - Please use i-syntax instead, i.e. fepois('Y~ i(f1, ref=1)', data = data) - instead of the former fepois('Y~ i(f1)', data = data, i_ref=1). - """ - ) context = {} if context is None else capture_context(context) # WLS currently not supported for Poisson regression weights = None weights_type = "aweights" - _estimation_input_checks( - fml=fml, - data=data, - vcov=vcov, - weights=weights, - ssc=ssc, - fixef_rm=fixef_rm, - collin_tol=collin_tol, - copy_data=copy_data, - store_data=store_data, - lean=lean, - fixef_tol=fixef_tol, - weights_type=weights_type, - use_compression=False, - reps=None, - seed=None, - split=split, - fsplit=fsplit, - separation_check=separation_check, - ) + args = locals() + filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + EstimationInputs(**filtered_args).validate() fixest = FixestMulti( data=data, @@ -725,24 +678,24 @@ def fepois( return fixest.fetch_model(0, print_fml=False) +@autofill_with_options def feglm( - fml: str, - data: DataFrameType, # type: ignore - family: str, + fml: Optional[str] = None, + data: Optional[DataFrameType] = None, # type: ignore + family: Optional[str] = None, vcov: Optional[Union[VcovTypeOptions, dict[str, str]]] = None, ssc: Optional[dict[str, Union[str, bool]]] = None, - fixef_rm: FixedRmOptions = 
"none", - fixef_tol: float = 1e-08, - iwls_tol: float = 1e-08, - iwls_maxiter: int = 25, - collin_tol: float = 1e-10, + fixef_rm: Optional[FixedRmOptions] = None, + fixef_tol: Optional[float] = None, + iwls_tol: Optional[float] = None, + iwls_maxiter: Optional[int] = None, + collin_tol: Optional[float] = None, separation_check: Optional[list[str]] = None, - solver: SolverOptions = "scipy.linalg.solve", - drop_intercept: bool = False, - i_ref1=None, - copy_data: bool = True, - store_data: bool = True, - lean: bool = False, + solver: Optional[SolverOptions] = None, + drop_intercept: Optional[bool] = None, + copy_data: Optional[bool] = None, + store_data: Optional[bool] = None, + lean: Optional[bool] = None, context: Optional[Union[int, Mapping[str, Any]]] = None, split: Optional[str] = None, fsplit: Optional[str] = None, @@ -806,11 +759,6 @@ def feglm( drop_intercept : bool, optional Whether to drop the intercept from the model, by default False. - i_ref1: None - Deprecated with pyfixest version 0.18.0. Please use i-syntax instead, i.e. - fepois('Y~ i(f1, ref=1)', data = data) instead of the former - fepois('Y~ i(f1)', data = data, i_ref=1). - copy_data : bool, optional Whether to copy the data before estimation, by default True. If set to False, the data is not copied, which can save memory but @@ -904,45 +852,15 @@ def feglm( f"Only families 'gaussian', 'logit' and 'probit'are supported but you asked for {family}." ) - if separation_check is None: - separation_check = ["fe"] - if ssc is None: - ssc = ssc_func() - if i_ref1 is not None: - raise FeatureDeprecationError( - """ - The 'i_ref1' function argument is deprecated with pyfixest version 0.18.0. - Please use i-syntax instead, i.e. fepois('Y~ i(f1, ref=1)', data = data) - instead of the former fepois('Y~ i(f1)', data = data, i_ref=1). 
- """ - ) - # WLS currently not supported for GLM regression weights = None weights_type = "aweights" context = {} if context is None else capture_context(context) - _estimation_input_checks( - fml=fml, - data=data, - vcov=vcov, - weights=weights, - ssc=ssc, - fixef_rm=fixef_rm, - collin_tol=collin_tol, - copy_data=copy_data, - store_data=store_data, - lean=lean, - fixef_tol=fixef_tol, - weights_type=weights_type, - use_compression=False, - reps=None, - seed=None, - split=split, - fsplit=fsplit, - separation_check=separation_check, - ) + args = locals() + filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + EstimationInputs(**filtered_args).validate() fixest = FixestMulti( data=data, @@ -983,6 +901,70 @@ def feglm( return fixest.fetch_model(0, print_fml=False) +def _check_type(value, expected_type, name): + if not isinstance(value, expected_type): + raise TypeError(f"Argument '{name}' must be {expected_type.__name__}, got {type(value).__name__}") + +def _check_value(value, valid_values, name): + if value not in valid_values: + raise ValueError(f"Argument '{name}' must be one of {valid_values}, got {value!r}") + +@dataclass +class EstimationInputs: + fml: str + data: pd.DataFrame + vcov: Optional[Union[str, dict[str, str]]] + weights: Optional[str] + ssc: dict[str, Union[str, bool]] + fixef_rm: str + collin_tol: float + copy_data: bool + store_data: bool + lean: bool + fixef_tol: float + weights_type: str + use_compression: bool + reps: Optional[int] + seed: Optional[int] + split: Optional[str] + fsplit: Optional[str] + separation_check: Optional[list[str]] = field(default=None) + + def validate(self): + _check_type(self.data, pd.DataFrame, "data") + _check_value(self.fixef_rm, ["none", "singleton"], "fixef_rm") + if not (0 < self.collin_tol < 1): + raise ValueError("collin_tol must be in (0, 1)") + if self.weights is not None: + _check_type(self.weights, str, "weights") + if self.weights not in self.data.columns: + raise 
ValueError(f"weights '{self.weights}' must be a column in data") + for bname in ["copy_data", "store_data", "lean", "use_compression"]: + _check_type(getattr(self, bname), bool, bname) + if not (0 < self.fixef_tol < 1): + raise ValueError("fixef_tol must be in (0, 1)") + _check_value(self.weights_type, ["aweights", "fweights"], "weights_type") + if self.use_compression and self.weights is not None: + raise NotImplementedError("Compressed regression is not supported with weights.") + if self.reps is not None: + _check_type(self.reps, int, "reps") + if self.reps <= 0: + raise ValueError("reps must be strictly positive") + if self.seed is not None: + _check_type(self.seed, int, "seed") + for cname in ["split", "fsplit"]: + cval = getattr(self, cname) + if cval is not None: + _check_type(cval, str, cname) + if cval not in self.data.columns: + raise KeyError(f"Column '{cval}' not found in data.") + if self.split is not None and self.fsplit is not None and self.split != self.fsplit: + raise ValueError(f"split and fsplit specified but not identical: {self.split} vs {self.fsplit}") + if self.separation_check is not None: + _check_type(self.separation_check, list, "separation_check") + if not all(x in ("fe", "ir") for x in self.separation_check): + raise ValueError("separation_check must be a list containing only 'fe' and/or 'ir'.") + def _estimation_input_checks( fml: str, data: DataFrameType, @@ -1110,3 +1092,15 @@ def _estimation_input_checks( raise ValueError( "The function argument `separation_check` must be a list of strings containing 'fe' and/or 'ir'." ) + + +def _enforce_types_and_not_none(args, required_types): + for key, expected_type in required_types.items(): + if args[key] is None: + raise TypeError( + f"Argument '{key}' must be set, but is None after autofill." + ) + if not isinstance(args[key], expected_type): + raise TypeError( + f"Argument '{key}' must be {expected_type}, got {type(args[key])}." 
+ ) diff --git a/pyfixest/options.py b/pyfixest/options.py new file mode 100644 index 000000000..c85a9eaf3 --- /dev/null +++ b/pyfixest/options.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import asdict, dataclass, field +from typing import Any, Optional, Union +from contextlib import contextmanager + +import pandas as pd + +from pyfixest.utils.utils import ssc as ssc_func + +__all__ = ["get_option", "option_context", "options", "set_option"] + + +@dataclass +class _Options: + data: Optional[pd.DataFrame] = None + vcov: Optional[Union[str, Mapping[str, str]]] = None + weights: Optional[str] = None + ssc: Optional[dict[str, Union[str, bool]]] = field(default_factory=ssc_func) + fixef_rm: str = "none" + fixef_tol: float = 1e-8 + collin_tol: float = 1e-10 + drop_intercept: bool = False + i_ref1: Optional[str] = None + copy_data: bool = True + store_data: bool = True + lean: bool = False + weights_type: str = "aweights" + solver: str = "scipy.linalg.solve" + demeaner_backend: str = "numba" + use_compression: bool = False + reps: int = 100 + context: Optional[Union[int, Mapping[str, Any]]] = None + seed: Optional[int] = None + split: Optional[str] = None + fsplit: Optional[str] = None + + separation_check: list[str] = field(default_factory=lambda: ["fe"]) + iwls_tol: Optional[float] = 1e-6 + iwls_maxiter: Optional[int] = 25 + + # helpers ------------ + def update(self, **kwargs): + for k, v in kwargs.items(): + if not hasattr(self, k): + raise KeyError(f"Unknown option '{k}'") + setattr(self, k, v) + + def to_dict(self): + return asdict(self) + + +options = _Options() + + +def set_option(**kwargs): + """Globally set default options (except `fml`).""" + options.update(**kwargs) + + +def get_option(name: str): + return getattr(options, name) + + +@contextmanager +def option_context(**kwargs): + "Temporarily override options inside a `with` block." 
+ old = options.to_dict() + try: + options.update(**kwargs) + yield + finally: + options.__dict__.update(old) From 25f37d419a365941d58880e87476cfc89dbd20b5 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Sat, 7 Jun 2025 15:56:21 +0200 Subject: [PATCH 2/2] global options --- pyfixest/estimation/estimation.py | 228 ++----------------------- pyfixest/options.py | 2 +- pyfixest/utils/api_input_checks.py | 95 +++++++++++ pyfixest/utils/dev_utils.py | 1 - tests/test_global_options.py | 263 +++++++++++++++++++++++++++++ 5 files changed, 376 insertions(+), 213 deletions(-) create mode 100644 pyfixest/utils/api_input_checks.py create mode 100644 tests/test_global_options.py diff --git a/pyfixest/estimation/estimation.py b/pyfixest/estimation/estimation.py index 990e69066..efad89c4b 100644 --- a/pyfixest/estimation/estimation.py +++ b/pyfixest/estimation/estimation.py @@ -1,11 +1,10 @@ import functools from collections.abc import Mapping +from dataclasses import dataclass, field from inspect import signature from typing import Any, Optional, Union -from dataclasses import dataclass, field import pandas as pd -import narwhals as nw from pyfixest.estimation.feols_ import Feols from pyfixest.estimation.fepois_ import Fepois @@ -20,9 +19,13 @@ from pyfixest.options import options from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas from pyfixest.utils.utils import capture_context +from pyfixest.utils.api_input_checks import _check_type, _check_value, EstimationInputs def autofill_with_options(func): + """ + Decorator to autofill the arguments of the estimation functions with the global options. 
+ """ @functools.wraps(func) def wrapper(*args, **kwargs): sig = signature(func) @@ -445,11 +448,15 @@ def _lspline(series: pd.Series, knots: list[float]) -> np.array: fit_D.ccv(treatment = "D", cluster = "group_id") ``` """ - context = {} if context is None else capture_context(context) + if not isinstance(data, pd.DataFrame): + data = _narwhals_to_pandas(data) + args = locals() - filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + filtered_args = { + k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args + } EstimationInputs(**filtered_args).validate() fixest = FixestMulti( @@ -636,7 +643,9 @@ def fepois( weights_type = "aweights" args = locals() - filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + filtered_args = { + k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args + } EstimationInputs(**filtered_args).validate() fixest = FixestMulti( @@ -859,7 +868,9 @@ def feglm( context = {} if context is None else capture_context(context) args = locals() - filtered_args = {k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args} + filtered_args = { + k: args[k] for k in EstimationInputs.__dataclass_fields__ if k in args + } EstimationInputs(**filtered_args).validate() fixest = FixestMulti( @@ -899,208 +910,3 @@ def feglm( return fixest else: return fixest.fetch_model(0, print_fml=False) - - -def _check_type(value, expected_type, name): - if not isinstance(value, expected_type): - raise TypeError(f"Argument '{name}' must be {expected_type.__name__}, got {type(value).__name__}") - -def _check_value(value, valid_values, name): - if value not in valid_values: - raise ValueError(f"Argument '{name}' must be one of {valid_values}, got {value!r}") - -@dataclass -class EstimationInputs: - fml: str - data: pd.DataFrame - vcov: Optional[Union[str, dict[str, str]]] - weights: Optional[str] - ssc: dict[str, Union[str, bool]] - fixef_rm: str - collin_tol: 
float - copy_data: bool - store_data: bool - lean: bool - fixef_tol: float - weights_type: str - use_compression: bool - reps: Optional[int] - seed: Optional[int] - split: Optional[str] - fsplit: Optional[str] - separation_check: Optional[list[str]] = field(default=None) - - def validate(self): - _check_type(self.data, pd.DataFrame, "data") - _check_value(self.fixef_rm, ["none", "singleton"], "fixef_rm") - if not (0 < self.collin_tol < 1): - raise ValueError("collin_tol must be in (0, 1)") - if self.weights is not None: - _check_type(self.weights, str, "weights") - if self.weights not in self.data.columns: - raise ValueError(f"weights '{self.weights}' must be a column in data") - for bname in ["copy_data", "store_data", "lean", "use_compression"]: - _check_type(getattr(self, bname), bool, bname) - if not (0 < self.fixef_tol < 1): - raise ValueError("fixef_tol must be in (0, 1)") - _check_value(self.weights_type, ["aweights", "fweights"], "weights_type") - if self.use_compression and self.weights is not None: - raise NotImplementedError("Compressed regression is not supported with weights.") - if self.reps is not None: - _check_type(self.reps, int, "reps") - if self.reps <= 0: - raise ValueError("reps must be strictly positive") - if self.seed is not None: - _check_type(self.seed, int, "seed") - for cname in ["split", "fsplit"]: - cval = getattr(self, cname) - if cval is not None: - _check_type(cval, str, cname) - if cval not in self.data.columns: - raise KeyError(f"Column '{cval}' not found in data.") - if self.split is not None and self.fsplit is not None and self.split != self.fsplit: - raise ValueError(f"split and fsplit specified but not identical: {self.split} vs {self.fsplit}") - if self.separation_check is not None: - _check_type(self.separation_check, list, "separation_check") - if not all(x in ("fe", "ir") for x in self.separation_check): - raise ValueError("separation_check must be a list containing only 'fe' and/or 'ir'.") - -def 
_estimation_input_checks( - fml: str, - data: DataFrameType, - vcov: Optional[Union[str, dict[str, str]]], - weights: Union[None, str], - ssc: dict[str, Union[str, bool]], - fixef_rm: str, - collin_tol: float, - copy_data: bool, - store_data: bool, - lean: bool, - fixef_tol: float, - weights_type: str, - use_compression: bool, - reps: Optional[int], - seed: Optional[int], - split: Optional[str], - fsplit: Optional[str], - separation_check: Optional[list[str]] = None, -): - if not isinstance(fml, str): - raise TypeError("fml must be a string") - if not isinstance(data, pd.DataFrame): - data = _narwhals_to_pandas(data) - if not isinstance(vcov, (str, dict, type(None))): - raise TypeError("vcov must be a string, dictionary, or None") - if not isinstance(fixef_rm, str): - raise TypeError("fixef_rm must be a string") - if not isinstance(collin_tol, float): - raise TypeError("collin_tol must be a float") - - if fixef_rm not in ["none", "singleton"]: - raise ValueError("fixef_rm must be either 'none' or 'singleton'") - if collin_tol <= 0: - raise ValueError("collin_tol must be greater than zero") - if collin_tol >= 1: - raise ValueError("collin_tol must be less than one") - - if not (isinstance(weights, str) or weights is None): - raise ValueError( - f"weights must be a string or None but you provided weights = {weights}." - ) - if weights is not None: - assert weights in data.columns, "weights must be a column in data" - - bool_args = [copy_data, store_data, lean] - for arg in bool_args: - if not isinstance(arg, bool): - raise TypeError(f"The function argument {arg} must be of type bool.") - - if not isinstance(fixef_tol, float): - raise TypeError( - """The function argument `fixef_tol` needs to be of - type float. - """ - ) - if fixef_tol <= 0: - raise ValueError( - """ - The function argument `fixef_tol` needs to be of - strictly larger than 0. 
- """ - ) - if fixef_tol >= 1: - raise ValueError( - """ - The function argument `fixef_tol` needs to be of - strictly smaller than 1. - """ - ) - - if weights_type not in ["aweights", "fweights"]: - raise ValueError( - f""" - The `weights_type` argument must be of type `aweights` - (for analytical / precision weights) or `fweights` - (for frequency weights) but it is {weights_type}. - """ - ) - - if not isinstance(use_compression, bool): - raise TypeError("The function argument `use_compression` must be of type bool.") - if use_compression and weights is not None: - raise NotImplementedError( - "Compressed regression is not supported with weights." - ) - - if reps is not None: - if not isinstance(reps, int): - raise TypeError("The function argument `reps` must be of type int.") - - if reps <= 0: - raise ValueError("The function argument `reps` must be strictly positive.") - - if seed is not None and not isinstance(seed, int): - raise TypeError("The function argument `seed` must be of type int.") - - if split is not None and not isinstance(split, str): - raise TypeError("The function argument split needs to be of type str.") - - if fsplit is not None and not isinstance(fsplit, str): - raise TypeError("The function argument fsplit needs to be of type str.") - - if split is not None and fsplit is not None and split != fsplit: - raise ValueError( - f""" - Arguments split and fsplit are both specified, but not identical. - split is specified as {split}, while fsplit is specified as {fsplit}. - """ - ) - - if isinstance(split, str) and split not in data.columns: - raise KeyError(f"Column '{split}' not found in data.") - - if isinstance(fsplit, str) and fsplit not in data.columns: - raise KeyError(f"Column '{fsplit}' not found in data.") - - if separation_check is not None: - if not isinstance(separation_check, list): - raise TypeError( - "The function argument `separation_check` must be of type list." 
- ) - - if not all(x in ["fe", "ir"] for x in separation_check): - raise ValueError( - "The function argument `separation_check` must be a list of strings containing 'fe' and/or 'ir'." - ) - - -def _enforce_types_and_not_none(args, required_types): - for key, expected_type in required_types.items(): - if args[key] is None: - raise TypeError( - f"Argument '{key}' must be set, but is None after autofill." - ) - if not isinstance(args[key], expected_type): - raise TypeError( - f"Argument '{key}' must be {expected_type}, got {type(args[key])}." - ) diff --git a/pyfixest/options.py b/pyfixest/options.py index c85a9eaf3..2be5eceb7 100644 --- a/pyfixest/options.py +++ b/pyfixest/options.py @@ -1,9 +1,9 @@ from __future__ import annotations from collections.abc import Mapping +from contextlib import contextmanager from dataclasses import asdict, dataclass, field from typing import Any, Optional, Union -from contextlib import contextmanager import pandas as pd diff --git a/pyfixest/utils/api_input_checks.py b/pyfixest/utils/api_input_checks.py new file mode 100644 index 000000000..39cadd823 --- /dev/null +++ b/pyfixest/utils/api_input_checks.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass, field +from typing import Optional, Union + +import pandas as pd + +@dataclass +class EstimationInputs: + fml: str + data: pd.DataFrame + vcov: Optional[Union[str, dict[str, str]]] + weights: Optional[str] + ssc: dict[str, Union[str, bool]] + fixef_rm: str + collin_tol: float + copy_data: bool + store_data: bool + lean: bool + fixef_tol: float + weights_type: str + use_compression: bool + reps: Optional[int] + seed: Optional[int] + split: Optional[str] + fsplit: Optional[str] + separation_check: Optional[list[str]] = field(default=None) + + "Dataclass to store and check the arguments of the estimation functions feols, fepois, feglm etc." + + def validate(self): + + "Validate the arguments of the EstimationInputs class." 
+ + # Step 1: Check types + _check_type(self.fml, str, "fml") + _check_type(self.data, pd.DataFrame, "data") + _check_type(self.vcov, (str, dict, type(None)), "vcov") + _check_type(self.weights, (str, type(None)), "weights") + _check_type(self.ssc, dict, "ssc") + _check_type(self.fixef_rm, str, "fixef_rm") + _check_type(self.collin_tol, float, "collin_tol") + _check_type(self.copy_data, bool, "copy_data") + _check_type(self.store_data, bool, "store_data") + _check_type(self.lean, bool, "lean") + _check_type(self.fixef_tol, float, "fixef_tol") + _check_type(self.weights_type, str, "weights_type") + _check_type(self.use_compression, bool, "use_compression") + _check_type(self.reps, (int, type(None)), "reps") + _check_type(self.seed, (int, type(None)), "seed") + _check_type(self.split, (str, type(None)), "split") + _check_type(self.fsplit, (str, type(None)), "fsplit") + _check_type(self.separation_check, (list, type(None)), "separation_check") + + # Step 2: Check values + if isinstance(self.vcov, str): + _check_value(self.vcov, ["iid", "HC1", "HC2", "HC3", "hetero"], "vcov") + elif isinstance(self.vcov, dict): + for key, value in self.vcov.items(): + if not isinstance(key, str): + raise TypeError(f"Key '{key}' in vcov must be a string.") + if not isinstance(value, str): + raise TypeError(f"Value '{value}' in vcov must be a string.") + _check_value(key, ["CRV1", "CRV3"], f"key: {key} in vcov") + _check_value(value, self.data.columns, f"value: {value} in vcov") + if self.weights is not None: + _check_value(self.weights, self.data.columns, "weights") + _check_value(self.fixef_rm, ["none", "singleton", "drop"], "fixef_rm") + if not (0 < self.collin_tol < 1): + raise ValueError("collin_tol must be in (0, 1).") + if not (0 < self.fixef_tol < 1): + raise ValueError("fixef_tol must be in (0, 1).") + _check_value(self.weights_type, ["aweights", "fweights"], "weights_type") + if not (0 < self.reps): + raise ValueError("reps must be strictly positive.") + if self.split is not 
None: + _check_value(self.split, self.data.columns, "split") + if self.fsplit is not None: + _check_value(self.fsplit, self.data.columns, "fsplit") + if self.split is not None and self.fsplit is not None: + if self.split != self.fsplit: + raise ValueError( + f"split and fsplit specified but not identical: {self.split} vs {self.fsplit}" + ) + _check_value(self.separation_check, ["fe", "ir", None], "separation_check") + +def _check_type(value, expected_type, name): + if not isinstance(value, expected_type): + raise TypeError( + f"Argument '{name}' must be {expected_type.__name__}, got {type(value).__name__}" + ) + +def _check_value(value, valid_values, name): + if value not in valid_values: + raise ValueError( + f"Argument '{name}' must be one of {valid_values}, got {value!r}" + ) diff --git a/pyfixest/utils/dev_utils.py b/pyfixest/utils/dev_utils.py index 77acf08e0..101480aa0 100644 --- a/pyfixest/utils/dev_utils.py +++ b/pyfixest/utils/dev_utils.py @@ -8,7 +8,6 @@ DataFrameType = IntoDataFrame - def _narwhals_to_pandas(data: IntoDataFrame) -> pd.DataFrame: # type: ignore return nw.from_native(data, eager_or_interchange_only=True).to_pandas() diff --git a/tests/test_global_options.py b/tests/test_global_options.py new file mode 100644 index 000000000..58b78dd87 --- /dev/null +++ b/tests/test_global_options.py @@ -0,0 +1,263 @@ +import pytest +import numpy as np +import pandas as pd +import pyfixest as pf +from pyfixest.options import get_option, set_option, option_context + +@pytest.fixture +def data(): + """Create test data.""" + np.random.seed(123) + n = 100 + data = pd.DataFrame({ + 'y': np.random.normal(0, 1, n), + 'x1': np.random.normal(0, 1, n), + 'x2': np.random.normal(0, 1, n), + 'id': np.repeat(range(10), 10), + 'weights': np.random.uniform(0.5, 1.5, n) + }) + return data + +def test_default_options_feols(data): + """Test that default options are correctly applied to feols.""" + # Get default options + default_vcov = get_option('vcov') + default_weights = 
def test_default_options_fepois(data):
    """Test that default options are correctly applied to fepois."""
    # Snapshot the global defaults the fitted model is expected to inherit.
    default_vcov = get_option('vcov')
    default_ssc = get_option('ssc')
    default_fixef_tol = get_option('fixef_tol')
    default_collin_tol = get_option('collin_tol')

    # Fit model with defaults
    # NOTE(review): the shared fixture draws `y` from a normal distribution, so it
    # contains negative values - confirm fepois accepts such an outcome.
    fit = pf.fepois('y ~ x1 + x2 | id', data=data)

    # Check that defaults were applied
    assert fit._vcov_type == default_vcov
    assert fit._weights is None  # No weights by default
    assert fit._ssc_dict == default_ssc
    assert fit._fixef_tol == default_fixef_tol
    assert fit._collin_tol == default_collin_tol


def test_default_options_feglm(data):
    """Test that default options are correctly applied to feglm."""
    # Snapshot the global defaults the fitted model is expected to inherit.
    default_vcov = get_option('vcov')
    default_ssc = get_option('ssc')
    default_fixef_tol = get_option('fixef_tol')
    default_collin_tol = get_option('collin_tol')

    # Fit model with defaults
    # NOTE(review): no `family` is passed here - confirm feglm provides a default.
    fit = pf.feglm('y ~ x1 + x2 | id', data=data)

    # Check that defaults were applied
    assert fit._vcov_type == default_vcov
    assert fit._weights is None  # No weights by default
    assert fit._ssc_dict == default_ssc
    assert fit._fixef_tol == default_fixef_tol
    assert fit._collin_tol == default_collin_tol
def test_option_context(data):
    """Test that options can be temporarily changed using the context manager."""
    # Original options
    original_vcov = get_option('vcov')
    original_weights = get_option('weights')

    # Custom options
    custom_vcov = "hetero"
    custom_weights = "weights"

    # Use context manager to temporarily change options
    with option_context(vcov=custom_vcov, weights=custom_weights):
        fit = pf.feols('y ~ x1 + x2 | id', data=data)
        assert fit._vcov_type == custom_vcov
        assert fit._weights_name == custom_weights

    # Check that original options are restored
    assert get_option('vcov') == original_vcov
    assert get_option('weights') == original_weights


def test_set_option(data):
    """Test that options can be permanently changed using set_option."""
    # Snapshot the originals before touching global state.
    original_vcov = get_option('vcov')
    original_weights = get_option('weights')

    try:
        # Set new options
        custom_vcov = "hetero"
        custom_weights = "weights"
        set_option(vcov=custom_vcov, weights=custom_weights)

        # Fit model with new defaults
        fit = pf.feols('y ~ x1 + x2 | id', data=data)
        assert fit._vcov_type == custom_vcov
        assert fit._weights_name == custom_weights

    finally:
        # Restore original options
        set_option(vcov=original_vcov, weights=original_weights)


def test_options_persistence(data):
    """Test that options persist across multiple function calls."""
    # Snapshot the originals BEFORE mutating global state; reading them back
    # inside `finally` would only return the already-modified values.
    original_vcov = get_option('vcov')
    original_weights = get_option('weights')

    custom_vcov = "hetero"
    custom_weights = "weights"

    try:
        # Mutate inside the `try` so a failure can never leak the custom options.
        set_option(vcov=custom_vcov, weights=custom_weights)

        # Fit multiple models
        fit1 = pf.feols('y ~ x1 | id', data=data)
        fit2 = pf.fepois('y ~ x1 | id', data=data)
        fit3 = pf.feglm('y ~ x1 | id', data=data)

        # Check that all models use the custom options
        assert fit1._vcov_type == custom_vcov
        assert fit2._vcov_type == custom_vcov
        assert fit3._vcov_type == custom_vcov

        assert fit1._weights_name == custom_weights
        assert fit2._weights_name == custom_weights
        assert fit3._weights_name == custom_weights

    finally:
        # Restore original options
        set_option(vcov=original_vcov, weights=original_weights)


def test_global_vs_direct_options(data):
    """Test that applying global options leads to the same results as providing arguments directly."""
    # Snapshot the originals so `finally` restores the real pre-test values.
    original_vcov = get_option('vcov')
    original_weights = get_option('weights')
    original_ssc = get_option('ssc')

    # Set custom options
    custom_vcov = "hetero"
    custom_weights = "weights"
    custom_ssc = {"adj": True, "fixef_k": "none"}

    # Fit model with direct arguments
    direct_fit = pf.feols(
        'y ~ x1 + x2 | id',
        data=data,
        vcov=custom_vcov,
        weights=custom_weights,
        ssc=custom_ssc,
    )

    try:
        # Set global options
        set_option(vcov=custom_vcov, weights=custom_weights, ssc=custom_ssc)

        # Fit model with global options
        global_fit = pf.feols('y ~ x1 + x2 | id', data=data)

        # Each entry: (label used in the failure message, direct result, global result).
        comparisons = [
            ("Coefficients", direct_fit.coef(), global_fit.coef()),
            ("Standard errors", direct_fit.se(), global_fit.se()),
            ("t-statistics", direct_fit.tstat(), global_fit.tstat()),
            ("p-values", direct_fit.pvalue(), global_fit.pvalue()),
            (
                "Confidence intervals",
                direct_fit.confint().values,
                global_fit.confint().values,
            ),
            ("Variance-covariance matrices", direct_fit._vcov, global_fit._vcov),
            ("Predictions", direct_fit.predict(), global_fit.predict()),
            ("Residuals", direct_fit.resid(), global_fit.resid()),
        ]
        for label, direct_value, global_value in comparisons:
            np.testing.assert_allclose(
                direct_value,
                global_value,
                rtol=1e-10,
                atol=1e-10,
                err_msg=f"{label} do not match between direct and global options",
            )

    finally:
        # Restore original options
        set_option(vcov=original_vcov, weights=original_weights, ssc=original_ssc)