Skip to content

Introduce JAX OLS Class #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions pyfixest/estimation/FixestMulti_.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyfixest.estimation.fegaussian_ import Fegaussian
from pyfixest.estimation.feiv_ import Feiv
from pyfixest.estimation.felogit_ import Felogit
from pyfixest.estimation.jax.olsjax_interface import OLSJAX_API
from pyfixest.estimation.feols_ import Feols, _check_vcov_input, _deparse_vcov_input
from pyfixest.estimation.feols_compressed_ import FeolsCompressed
from pyfixest.estimation.fepois_ import Fepois
Expand All @@ -16,7 +17,6 @@
from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas
from pyfixest.utils.utils import capture_context


class FixestMulti:
"""A class to estimate multiple regression models with fixed effects."""

Expand Down Expand Up @@ -279,7 +279,11 @@ def _estimate_all_models(
FIT: Union[Feols, Feiv, Fepois]

if _method == "feols" and not _is_iv:
FIT = Feols(

backend = "jax"
OLSCLASS = Feols if backend != "jax" else OLSJAX_API

FIT = OLSCLASS(
FixestFormula=FixestFormula,
data=_data,
ssc_dict=_ssc_dict,
Expand All @@ -299,11 +303,15 @@ def _estimate_all_models(
sample_split_value=sample_split_value,
sample_split_var=_splitvar,
)
FIT.prepare_model_matrix()
FIT.demean()
FIT.to_array()
FIT.drop_multicol_vars()
FIT.wls_transform()

if backend != "jax":
FIT.prepare_model_matrix()
FIT.demean()
FIT.to_array()
FIT.drop_multicol_vars()
FIT.wls_transform()


elif _method == "feols" and _is_iv:
FIT = Feiv(
FixestFormula=FixestFormula,
Expand All @@ -330,6 +338,7 @@ def _estimate_all_models(
FIT.to_array()
FIT.drop_multicol_vars()
FIT.wls_transform()

elif _method == "fepois":
FIT = Fepois(
FixestFormula=FixestFormula,
Expand Down Expand Up @@ -475,17 +484,25 @@ def _estimate_all_models(

FIT.get_fit()
# if X is empty: no inference (empty X only as shorthand for demeaning)
if not FIT._X_is_empty:
# inference
vcov_type = _get_vcov_type(vcov, fval)
FIT.vcov(vcov=vcov_type, data=FIT._data)
if backend != "jax":
if not FIT._X_is_empty:
# inference
vcov_type = _get_vcov_type(vcov, fval)
FIT.vcov(vcov=vcov_type, data=FIT._data)

FIT.get_inference()
# other regression stats
if _method == "feols" and not FIT._is_iv:
FIT.get_performance()
if isinstance(FIT, Feiv):
FIT.first_stage()

else:

#import pdb; pdb.set_trace()
FIT.vcov(type = "iid")
FIT.convert_attributes_to_numpy()
FIT.get_inference()
# other regression stats
if _method == "feols" and not FIT._is_iv:
FIT.get_performance()
if isinstance(FIT, Feiv):
FIT.first_stage()

# delete large attributescl
FIT._clear_attributes()
Expand Down
2 changes: 1 addition & 1 deletion pyfixest/estimation/demean_.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _set_demeaner_backend(demeaner_backend: Literal["numba", "jax"]) -> Callable
if demeaner_backend == "numba":
return demean
elif demeaner_backend == "jax":
from pyfixest.estimation.demean_jax_ import demean_jax
from pyfixest.estimation.jax.demean_jax_ import demean_jax

return demean_jax
else:
Expand Down
1 change: 1 addition & 0 deletions pyfixest/estimation/feols_.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ def get_inference(self, alpha: float = 0.05) -> None:
-------
None
"""

_vcov = self._vcov
_beta_hat = self._beta_hat
_vcov_type = self._vcov_type
Expand Down
145 changes: 145 additions & 0 deletions pyfixest/estimation/jax/OLSJAX.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from typing import Optional

import jax
import jax.numpy as jnp
import pandas as pd

from pyfixest.estimation.jax.demean_jax_ import demean_jax

class OLSJAX:
def __init__(
Comment on lines +9 to +10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstrings?

self,
X: jax.Array,
Y: jax.Array,
fe: Optional[jax.Array] = None,
weights: Optional[jax.Array] = None,
vcov: Optional[str] = None,
):

"""
Class to run OLS regression in JAX.

Parameters
----------
X : jax.Array
N x k matrix of independent variables.
Y : jax.Array
Dependent variable. N x 1 matrix.
fe : jax.Array, optional
Fixed effects. N x 1 matrix of integers. The default is None.
weights: jax.Array, optional
Weights. N x 1 matrix. The default is None.
vcov : str, optional
Type of covariance matrix. The default is None. Options are:
- "iid" (default): iid errors
- "HC1": heteroskedasticity robust
- "HC2": heteroskedasticity robust
- "HC3": heteroskedasticity robust
- "CRV1": cluster robust. In this case, please provide a dictionary
with the cluster variable as key and the name of the cluster variable as value.
"""

self.X_orignal = X
self.Y_orignal = Y
self.fe = fe
self.N = X.shape[0]
self.k = X.shape[1]
self.weights = jnp.ones(self.N) if weights is None else weights
self.vcov_type = "iid" if vcov is None else vcov

def fit(self):
self.Y, self.X = self.demean(
Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights
)
self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious whether using lx.linear_solve gets speed-gains here; example

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speed gains look superb, how do you feel about adding lineax as an optional dependency to the JAX env?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm game; Kidger's stuff is very high quality and stable, so I doubt that it will break anytime soon.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'll test out lineax's solver on this PR over the weekend; expect a commit

self.get_fit()
self.scores
self.vcov(vcov_type=self.vcov_type)
self.inference()

def get_fit(self):
self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]

@property
def residuals(self):
return self.Y - self.X @ self.beta

def vcov(self, vcov_type: str):
bread = self.bread
meat = self.meat(type=vcov_type)
if vcov_type == "iid":
self.vcov = bread * meat
else:
self.vcov = bread @ meat @ bread

return self.vcov

@property
def bread(self):
return jnp.linalg.inv(self.X.T @ self.X)

@property
def leverage(self):
return jnp.sum(self.X * (self.X @ jnp.linalg.inv(self.X.T @ self.X)), axis=1)

@property
def scores(self):
return self.X * self.residuals

def meat(self, type: str):
if type == "iid":
return self.meat_iid
elif type == "HC1":
return self.meat_hc1
elif type == "HC2":
return self.meat_hc2
elif type == "HC3":
return self.meat_hc3
elif type == "CRV1":
return self.meat_crv1
else:
raise ValueError("Invalid type")

@property
def meat_iid(self):
return jnp.sum(self.residuals**2) / (self.N - self.k)

@property
def meat_hc1(self):
return self.scores.T @ self.scores

def meat_hc2(self):
self.leverage
transformed_scores = self.scores / jnp.sqrt(1 - self.leverage)
return transformed_scores.T @ transformed_scores

def meat_hc3(self):
self.leverage
transformed_scores = self.scores / (1 - self.leverage)
return transformed_scores.T @ transformed_scores

@property
def meat_crv1(self):
raise NotImplementedError("CRV1 is not implemented")

def predict(self, X):
X = jnp.array(X)
return X @ self.beta

def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array):

if fe is not None:
if not jnp.issubdtype(fe.dtype, jnp.integer):
raise ValueError("Fixed effects must be integers")

YX = jnp.concatenate((Y, X), axis=1)
YXd, success = demean_jax(
x=YX, flist=fe, weights=weights, output="jax"
)
Yd = YXd[:, 0].reshape(-1, 1)
Xd = YXd[:, 1:]

return Yd, Xd

else:
return Y, X
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def demean_jax(
weights: np.ndarray,
tol: float = 1e-08,
maxiter: int = 100_000,
output: str = "numpy",
) -> tuple[np.ndarray, bool]:
"""Fast and reliable JAX implementation with static shapes."""
# Enable float64 precision
Expand All @@ -89,4 +90,10 @@ def demean_jax(
result_jax, converged = _demean_jax_impl(
x_jax, flist_jax, weights_jax, n_groups, tol, maxiter
)
return np.array(result_jax), converged

if output == "numpy":
return np.array(result_jax), converged
elif output == "jax":
return result_jax, converged
else:
raise ValueError("Invalid output type")
109 changes: 109 additions & 0 deletions pyfixest/estimation/jax/olsjax_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pyfixest.estimation.feols_ import Feols
from pyfixest.estimation.feols_ import Feols, _drop_multicollinear_variables
from pyfixest.estimation.FormulaParser import FixestFormula
import pandas as pd
import numpy as np
from typing import Union, Optional, Mapping, Any, Literal
import jax.numpy as jnp
from pyfixest.estimation.jax.OLSJAX import OLSJAX

class OLSJAX_API(Feols):

def __init__(
self,
FixestFormula: FixestFormula,
data: pd.DataFrame,
ssc_dict: dict[str, Union[str, bool]],
drop_singletons: bool,
drop_intercept: bool,
weights: Optional[str],
weights_type: Optional[str],
collin_tol: float,
fixef_tol: float,
lookup_demeaned_data: dict[str, pd.DataFrame],
solver: Literal[
"np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
] = "np.linalg.solve",
demeaner_backend: Literal["numba", "jax"] = "numba",
store_data: bool = True,
copy_data: bool = True,
lean: bool = False,
context: Union[int, Mapping[str, Any]] = 0,
sample_split_var: Optional[str] = None,
sample_split_value: Optional[Union[str, int]] = None,
) -> None:
super().__init__(
FixestFormula=FixestFormula,
data=data,
ssc_dict=ssc_dict,
drop_singletons=drop_singletons,
drop_intercept=drop_intercept,
weights=weights,
weights_type=weights_type,
collin_tol=collin_tol,
fixef_tol=fixef_tol,
lookup_demeaned_data=lookup_demeaned_data,
solver=solver,
store_data=store_data,
copy_data=copy_data,
lean=lean,
sample_split_var=sample_split_var,
sample_split_value=sample_split_value,
context=context,
demeaner_backend=demeaner_backend,
)

self.prepare_model_matrix()
self.to_jax_array()

# later to be set in multicoll method
self._N, self._k = self._X_jax.shape

self.olsjax = OLSJAX(
X=self._X_jax,
Y=self._Y_jax,
fe=self._fe_jax,
weights=self._weights_jax,
vcov="iid",
)
#import pdb; pdb.set_trace()
self.olsjax.Y, self.olsjax.X = self.olsjax.demean(Y = self._Y_jax, X = self._X_jax, fe = self._fe_jax, weights = self._weights_jax.flatten())

def to_jax_array(self):

self._X_jax = jnp.array(self._X)
self._Y_jax = jnp.array(self._Y)
self._fe_jax = jnp.array(self._fe)
self._weights_jax = jnp.array(self._weights)


def get_fit(self):

self.olsjax.get_fit()
self._beta_hat = self.olsjax.beta.flatten()
self._u_hat = self.olsjax.residuals
self._scores = self.olsjax.scores

def vcov(self, type: str):

self._vcov_type = type
self.olsjax.vcov(vcov_type=type)
self._vcov = self.olsjax.vcov

return self

def convert_attributes_to_numpy(self):
"Convert core attributes from jax to numpy arrays."
attr = ["_beta_hat", "_u_hat", "_scores", "_vcov"]
for a in attr:
# convert to numpy
setattr(self, a, np.array(getattr(self, a)))









2 changes: 1 addition & 1 deletion tests/test_demean.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from pyfixest.estimation.demean_ import _set_demeaner_backend, demean, demean_model
from pyfixest.estimation.demean_jax_ import demean_jax
from pyfixest.estimation.jax.demean_jax_ import demean_jax


@pytest.mark.parametrize(
Expand Down
Loading
Loading