diff --git a/pyfixest/estimation/FixestMulti_.py b/pyfixest/estimation/FixestMulti_.py index 1602d9072..62c9666d9 100644 --- a/pyfixest/estimation/FixestMulti_.py +++ b/pyfixest/estimation/FixestMulti_.py @@ -8,6 +8,7 @@ from pyfixest.estimation.fegaussian_ import Fegaussian from pyfixest.estimation.feiv_ import Feiv from pyfixest.estimation.felogit_ import Felogit +from pyfixest.estimation.jax.olsjax_interface import OLSJAX_API from pyfixest.estimation.feols_ import Feols, _check_vcov_input, _deparse_vcov_input from pyfixest.estimation.feols_compressed_ import FeolsCompressed from pyfixest.estimation.fepois_ import Fepois @@ -16,7 +17,6 @@ from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas from pyfixest.utils.utils import capture_context - class FixestMulti: """A class to estimate multiple regression models with fixed effects.""" @@ -279,7 +279,11 @@ def _estimate_all_models( FIT: Union[Feols, Feiv, Fepois] if _method == "feols" and not _is_iv: - FIT = Feols( + + backend = "jax" + OLSCLASS = Feols if backend != "jax" else OLSJAX_API + + FIT = OLSCLASS( FixestFormula=FixestFormula, data=_data, ssc_dict=_ssc_dict, @@ -299,11 +303,15 @@ def _estimate_all_models( sample_split_value=sample_split_value, sample_split_var=_splitvar, ) - FIT.prepare_model_matrix() - FIT.demean() - FIT.to_array() - FIT.drop_multicol_vars() - FIT.wls_transform() + + if backend != "jax": + FIT.prepare_model_matrix() + FIT.demean() + FIT.to_array() + FIT.drop_multicol_vars() + FIT.wls_transform() + + elif _method == "feols" and _is_iv: FIT = Feiv( FixestFormula=FixestFormula, @@ -330,6 +338,7 @@ def _estimate_all_models( FIT.to_array() FIT.drop_multicol_vars() FIT.wls_transform() + elif _method == "fepois": FIT = Fepois( FixestFormula=FixestFormula, @@ -475,17 +484,25 @@ def _estimate_all_models( FIT.get_fit() # if X is empty: no inference (empty X only as shorthand for demeaning) - if not FIT._X_is_empty: - # inference - vcov_type = _get_vcov_type(vcov, fval) - 
class OLSJAX:
    """
    Ordinary least squares in JAX with optional fixed-effects demeaning.

    Parameters
    ----------
    X : jax.Array
        N x k matrix of independent variables.
    Y : jax.Array
        Dependent variable, N x 1 matrix.
    fe : jax.Array, optional
        Integer-coded fixed effects. If None (default), no demeaning is done.
    weights : jax.Array, optional
        Observation weights. Defaults to a vector of ones. Within this class
        the weights are only forwarded to the demeaning step; the least
        squares solve itself is unweighted.
    vcov : str, optional
        Covariance estimator used by `fit()`. One of "iid" (default), "HC1",
        "HC2", "HC3". "CRV1" is accepted as a name but not implemented.

    Attributes
    ----------
    beta : jax.Array
        Coefficients, set by `get_fit()`.
    vcov_matrix : jax.Array
        Most recently computed covariance matrix, set by `vcov()`.
    """

    def __init__(
        self,
        X: jax.Array,
        Y: jax.Array,
        fe: Optional[jax.Array] = None,
        weights: Optional[jax.Array] = None,
        vcov: Optional[str] = None,
    ):
        # NOTE: the misspelled attribute names (`*_orignal`) are kept as they
        # are part of the object's visible attribute surface.
        self.X_orignal = X
        self.Y_orignal = Y
        self.fe = fe
        self.N = X.shape[0]
        self.k = X.shape[1]
        self.weights = jnp.ones(self.N) if weights is None else weights
        self.vcov_type = "iid" if vcov is None else vcov

    def fit(self):
        """Demean the data, estimate the coefficients and the covariance."""
        self.Y, self.X = self.demean(
            Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights
        )
        self.get_fit()
        # `vcov()` rebinds `self.vcov` to its result (see there), so go through
        # the class to keep `fit()` callable more than once.
        type(self).vcov(self, vcov_type=self.vcov_type)

    def get_fit(self):
        """Solve the least squares problem and store the result in `beta`."""
        self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]

    @property
    def residuals(self):
        """Residuals `Y - X @ beta` of the (demeaned) model."""
        return self.Y - self.X @ self.beta

    def vcov(self, vcov_type: str):
        """
        Compute the covariance matrix of the coefficients.

        Parameters
        ----------
        vcov_type : str
            "iid", "HC1", "HC2", "HC3" or "CRV1" (the latter raises
            NotImplementedError).

        Returns
        -------
        jax.Array
            The k x k covariance matrix. It is also stored as `vcov_matrix`.

        Raises
        ------
        ValueError
            If `vcov_type` is not a known estimator name.
        """
        bread = self.bread
        meat = self.meat(type=vcov_type)
        if vcov_type == "iid":
            # here the "meat" is the scalar residual variance
            vcov_matrix = bread * meat
        else:
            vcov_matrix = bread @ meat @ bread

        self.vcov_matrix = vcov_matrix
        # Backward compatibility: existing callers read the result back from
        # `obj.vcov` after the call, so the instance attribute is rebound to
        # the matrix. After this line `obj.vcov(...)` is no longer callable on
        # this instance; use `type(obj).vcov(obj, ...)` or `obj.vcov_matrix`.
        self.vcov = vcov_matrix

        return vcov_matrix

    @property
    def bread(self):
        """Inverse of `X'X`."""
        return jnp.linalg.inv(self.X.T @ self.X)

    @property
    def leverage(self):
        """Diagonal of the hat matrix `X (X'X)^-1 X'`, shape (N,)."""
        return jnp.sum(self.X * (self.X @ self.bread), axis=1)

    @property
    def scores(self):
        """Row-wise product of `X` and the residuals."""
        return self.X * self.residuals

    def meat(self, type: str):
        """
        Return the "meat" of the sandwich estimator for the given type.

        Raises
        ------
        ValueError
            If `type` is not one of "iid", "HC1", "HC2", "HC3", "CRV1".
        """
        if type == "iid":
            return self.meat_iid
        if type == "HC1":
            return self.meat_hc1
        if type == "HC2":
            # plain methods (not properties): they have to be called
            return self.meat_hc2()
        if type == "HC3":
            return self.meat_hc3()
        if type == "CRV1":
            return self.meat_crv1
        raise ValueError("Invalid type")

    @property
    def meat_iid(self):
        """Residual variance estimate `sum(u^2) / (N - k)`."""
        return jnp.sum(self.residuals**2) / (self.N - self.k)

    @property
    def meat_hc1(self):
        """
        Outer product of the scores.

        NOTE(review): no N / (N - k) rescaling is applied here; presumably a
        small sample correction is handled elsewhere - confirm.
        """
        return self.scores.T @ self.scores

    def meat_hc2(self):
        """Outer product of the scores scaled by `1 / sqrt(1 - leverage)`."""
        # column vector so that the division broadcasts over the k columns
        scale = jnp.sqrt(1 - self.leverage).reshape(-1, 1)
        transformed_scores = self.scores / scale
        return transformed_scores.T @ transformed_scores

    def meat_hc3(self):
        """Outer product of the scores scaled by `1 / (1 - leverage)`."""
        scale = (1 - self.leverage).reshape(-1, 1)
        transformed_scores = self.scores / scale
        return transformed_scores.T @ transformed_scores

    @property
    def meat_crv1(self):
        """Cluster robust meat; not implemented."""
        raise NotImplementedError("CRV1 is not implemented")

    def predict(self, X):
        """Return `X @ beta` for new data `X`."""
        X = jnp.array(X)
        return X @ self.beta

    def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array):
        """
        Demean `Y` and `X` by the fixed effects `fe`.

        Returns
        -------
        tuple
            `(Y, X)` unchanged if `fe` is None, otherwise the demeaned
            `(Yd, Xd)` with `Yd` reshaped to a single column.

        Raises
        ------
        ValueError
            If `fe` does not have an integer dtype.
        """
        if fe is None:
            return Y, X

        if not jnp.issubdtype(fe.dtype, jnp.integer):
            raise ValueError("Fixed effects must be integers")

        # imported lazily: only needed when fixed effects are present
        from pyfixest.estimation.jax.demean_jax_ import demean_jax

        YX = jnp.concatenate((Y, X), axis=1)
        # NOTE(review): the convergence flag is currently not checked.
        YXd, _converged = demean_jax(x=YX, flist=fe, weights=weights, output="jax")
        Yd = YXd[:, 0].reshape(-1, 1)
        Xd = YXd[:, 1:]

        return Yd, Xd
class OLSJAX_API(Feols):
    """
    `Feols`-compatible front end that delegates estimation to `OLSJAX`.

    The constructor takes the same arguments as it forwards to
    `Feols.__init__`, builds the model matrix, converts the relevant arrays
    to JAX arrays and demeans them via `OLSJAX.demean`. Fitting and the
    covariance computation are then delegated to the wrapped `OLSJAX`
    instance stored in `self.olsjax`.
    """

    def __init__(
        self,
        FixestFormula: FixestFormula,
        data: pd.DataFrame,
        ssc_dict: dict[str, Union[str, bool]],
        drop_singletons: bool,
        drop_intercept: bool,
        weights: Optional[str],
        weights_type: Optional[str],
        collin_tol: float,
        fixef_tol: float,
        lookup_demeaned_data: dict[str, pd.DataFrame],
        solver: Literal[
            "np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
        ] = "np.linalg.solve",
        demeaner_backend: Literal["numba", "jax"] = "numba",
        store_data: bool = True,
        copy_data: bool = True,
        lean: bool = False,
        context: Union[int, Mapping[str, Any]] = 0,
        sample_split_var: Optional[str] = None,
        sample_split_value: Optional[Union[str, int]] = None,
    ) -> None:
        super().__init__(
            FixestFormula=FixestFormula,
            data=data,
            ssc_dict=ssc_dict,
            drop_singletons=drop_singletons,
            drop_intercept=drop_intercept,
            weights=weights,
            weights_type=weights_type,
            collin_tol=collin_tol,
            fixef_tol=fixef_tol,
            lookup_demeaned_data=lookup_demeaned_data,
            solver=solver,
            store_data=store_data,
            copy_data=copy_data,
            lean=lean,
            sample_split_var=sample_split_var,
            sample_split_value=sample_split_value,
            context=context,
            demeaner_backend=demeaner_backend,
        )

        self.prepare_model_matrix()
        self.to_jax_array()

        # later to be set in multicoll method
        self._N, self._k = self._X_jax.shape

        self.olsjax = OLSJAX(
            X=self._X_jax,
            Y=self._Y_jax,
            fe=self._fe_jax,
            weights=self._weights_jax,
            vcov="iid",
        )
        # Demean right away so that `get_fit()` can be called directly,
        # without going through `OLSJAX.fit()`.
        self.olsjax.Y, self.olsjax.X = self.olsjax.demean(
            Y=self._Y_jax,
            X=self._X_jax,
            fe=self._fe_jax,
            weights=self._weights_jax.flatten(),
        )

    def to_jax_array(self):
        """Create JAX copies of the design matrix, outcome, fixed effects and weights."""
        self._X_jax = jnp.array(self._X)
        self._Y_jax = jnp.array(self._Y)
        # `OLSJAX.demean` treats `fe is None` as "no fixed effects", so a
        # missing `_fe` has to stay None instead of being passed to jnp.array.
        # NOTE(review): assumes `_fe` is None for models without fixed
        # effects - confirm against `prepare_model_matrix`.
        self._fe_jax = None if self._fe is None else jnp.array(self._fe)
        self._weights_jax = jnp.array(self._weights)

    def get_fit(self):
        """Fit the wrapped model and copy coefficients, residuals and scores."""
        self.olsjax.get_fit()
        self._beta_hat = self.olsjax.beta.flatten()
        # NOTE(review): residuals are stored with whatever shape `OLSJAX`
        # returns them in (not flattened) - confirm downstream expectations.
        self._u_hat = self.olsjax.residuals
        self._scores = self.olsjax.scores

    def vcov(self, type: str):
        """
        Compute the covariance matrix via the wrapped `OLSJAX` instance.

        Parameters
        ----------
        type : str
            Name of the covariance estimator, forwarded as `vcov_type`.

        Returns
        -------
        OLSJAX_API
            `self`, with `_vcov_type` and `_vcov` set.
        """
        self._vcov_type = type
        # Call through the class and use the return value: `OLSJAX.vcov`
        # rebinds the instance attribute `vcov` to its result, so
        # `self.olsjax.vcov(...)` would fail on any call after the first.
        self._vcov = OLSJAX.vcov(self.olsjax, vcov_type=type)

        return self

    def convert_attributes_to_numpy(self):
        """Convert core attributes from jax to numpy arrays."""
        for name in ("_beta_hat", "_u_hat", "_scores", "_vcov"):
            setattr(self, name, np.array(getattr(self, name)))