-
Notifications
You must be signed in to change notification settings - Fork 58
Introduce JAX OLS Class #790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
b2f5b63
2e3e4a5
9291421
ffd3217
80bd6a4
62da934
2c92c3d
3f595cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from typing import Optional | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import pandas as pd | ||
|
||
from pyfixest.estimation.jax.demean_jax_ import demean_jax | ||
|
||
class OLSJAX:
    def __init__(
        self,
        X: jax.Array,
        Y: jax.Array,
        fe: Optional[jax.Array] = None,
        weights: Optional[jax.Array] = None,
        vcov: Optional[str] = None,
    ):
        """
        Class to run OLS regression in JAX.

        Parameters
        ----------
        X : jax.Array
            N x k matrix of independent variables.
        Y : jax.Array
            Dependent variable. N x 1 matrix.
        fe : jax.Array, optional
            Fixed effects. N x 1 matrix of integers. The default is None.
        weights: jax.Array, optional
            Weights. N x 1 matrix. The default is None.
        vcov : str, optional
            Type of covariance matrix. The default is None. Options are:
            - "iid" (default): iid errors
            - "HC1": heteroskedasticity robust
            - "HC2": heteroskedasticity robust
            - "HC3": heteroskedasticity robust
            - "CRV1": cluster robust. In this case, please provide a dictionary
            with the cluster variable as key and the name of the cluster variable as value.
        """
        # NOTE(review): "orignal" is misspelled, but the attribute names are
        # kept unchanged so existing callers that read them keep working.
        self.X_orignal = X
        self.Y_orignal = Y
        self.fe = fe
        self.N = X.shape[0]
        self.k = X.shape[1]
        self.weights = jnp.ones(self.N) if weights is None else weights
        self.vcov_type = "iid" if vcov is None else vcov

    def fit(self):
        """Demean by fixed effects (if any), then estimate beta, vcov, and inference."""
        self.Y, self.X = self.demean(
            Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights
        )
        # BUGFIX(review): the original ran jnp.linalg.lstsq here AND inside
        # get_fit(), solving the same system twice; it also contained a no-op
        # bare `self.scores` statement. Both removed.
        self.get_fit()
        self.vcov(vcov_type=self.vcov_type)
        self.inference()

    def get_fit(self):
        """Estimate the OLS coefficients beta (k x 1) via least squares."""
        self.beta = jnp.linalg.lstsq(self.X, self.Y)[0]

    @property
    def residuals(self):
        """Return the N x 1 matrix of residuals Y - X @ beta."""
        return self.Y - self.X @ self.beta

    def vcov(self, vcov_type: str):
        """
        Compute and store the variance-covariance matrix of beta.

        Parameters
        ----------
        vcov_type : str
            One of "iid", "HC1", "HC2", "HC3", "CRV1".

        Returns
        -------
        jax.Array
            The k x k variance-covariance matrix.
        """
        bread = self.bread
        meat = self.meat(type=vcov_type)
        # NOTE(review): assigning to self.vcov replaces this bound method with
        # the computed array on the instance. Downstream code reads the result
        # from the `vcov` attribute afterwards, but the method cannot be
        # called a second time on the same instance. Kept for backward
        # compatibility.
        if vcov_type == "iid":
            # iid meat is a scalar (sigma^2), so scalar * matrix
            self.vcov = bread * meat
        else:
            self.vcov = bread @ meat @ bread
        return self.vcov

    def inference(self):
        """
        Compute standard errors, t-statistics, and p-values from the stored vcov.

        BUGFIX(review): fit() called self.inference(), but no such method was
        defined anywhere in the file, so fit() raised AttributeError. This is
        a minimal implementation; p-values use a normal approximation to the
        t distribution.
        """
        from jax.scipy.stats import norm

        # after vcov() has run, self.vcov holds the k x k covariance array
        self.se = jnp.sqrt(jnp.diag(self.vcov)).reshape(-1, 1)
        self.tstat = self.beta / self.se
        self.pvalue = 2 * (1 - norm.cdf(jnp.abs(self.tstat)))

    @property
    def bread(self):
        """Return the 'bread' (X'X)^{-1} of the sandwich estimator."""
        return jnp.linalg.inv(self.X.T @ self.X)

    @property
    def leverage(self):
        """Return the leverage (diagonal of the hat matrix) as a length-N vector."""
        return jnp.sum(self.X * (self.X @ jnp.linalg.inv(self.X.T @ self.X)), axis=1)

    @property
    def scores(self):
        """Return the N x k score matrix X_i * u_i."""
        return self.X * self.residuals

    def meat(self, type: str):
        """Return the 'meat' of the sandwich estimator for the given vcov type."""
        if type == "iid":
            return self.meat_iid
        elif type == "HC1":
            return self.meat_hc1
        elif type == "HC2":
            return self.meat_hc2
        elif type == "HC3":
            return self.meat_hc3
        elif type == "CRV1":
            return self.meat_crv1
        else:
            raise ValueError("Invalid type")

    @property
    def meat_iid(self):
        """Return sigma^2 = u'u / (N - k) for iid errors."""
        return jnp.sum(self.residuals**2) / (self.N - self.k)

    @property
    def meat_hc1(self):
        """Return the HC1 meat sum_i x_i x_i' u_i^2."""
        return self.scores.T @ self.scores

    # BUGFIX(review): meat_hc2 / meat_hc3 were plain methods although meat()
    # accesses them like the other meats, so meat("HC2"/"HC3") returned a
    # bound method and the sandwich product failed. Additionally, leverage has
    # shape (N,) while scores has shape (N, k); the leverage must be reshaped
    # to (N, 1) so the division broadcasts row-wise.
    @property
    def meat_hc2(self):
        """Return the HC2 meat with scores scaled by 1 / sqrt(1 - leverage)."""
        transformed_scores = self.scores / jnp.sqrt(1 - self.leverage.reshape(-1, 1))
        return transformed_scores.T @ transformed_scores

    @property
    def meat_hc3(self):
        """Return the HC3 meat with scores scaled by 1 / (1 - leverage)."""
        transformed_scores = self.scores / (1 - self.leverage.reshape(-1, 1))
        return transformed_scores.T @ transformed_scores

    @property
    def meat_crv1(self):
        """Cluster-robust (CRV1) meat; not yet implemented."""
        raise NotImplementedError("CRV1 is not implemented")

    def predict(self, X):
        """Return fitted values X @ beta for a new design matrix X."""
        X = jnp.array(X)
        return X @ self.beta

    def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array):
        """
        Demean Y and X by the fixed effects fe, weighted by weights.

        Returns the inputs unchanged when fe is None.

        Parameters
        ----------
        Y : jax.Array
            N x 1 dependent variable.
        X : jax.Array
            N x k design matrix.
        fe : jax.Array
            N x 1 integer fixed-effects identifiers, or None.
        weights : jax.Array
            Length-N weights vector.

        Returns
        -------
        tuple[jax.Array, jax.Array]
            Demeaned (Y, X).
        """
        if fe is None:
            return Y, X

        if not jnp.issubdtype(fe.dtype, jnp.integer):
            raise ValueError("Fixed effects must be integers")

        # demean Y and X in a single pass over the fixed effects
        YX = jnp.concatenate((Y, X), axis=1)
        YXd, success = demean_jax(x=YX, flist=fe, weights=weights, output="jax")
        Yd = YXd[:, 0].reshape(-1, 1)
        Xd = YXd[:, 1:]
        return Yd, Xd
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from pyfixest.estimation.feols_ import Feols | ||
from pyfixest.estimation.feols_ import Feols, _drop_multicollinear_variables | ||
from pyfixest.estimation.FormulaParser import FixestFormula | ||
import pandas as pd | ||
import numpy as np | ||
from typing import Union, Optional, Mapping, Any, Literal | ||
import jax.numpy as jnp | ||
from pyfixest.estimation.jax.OLSJAX import OLSJAX | ||
|
||
class OLSJAX_API(Feols):
    """
    Feols-compatible API wrapper around the OLSJAX estimator.

    Prepares the model matrices via the Feols machinery, converts them to
    jax arrays, and delegates estimation, vcov computation, and inference
    to an internal OLSJAX model.
    """

    def __init__(
        self,
        FixestFormula: FixestFormula,
        data: pd.DataFrame,
        ssc_dict: dict[str, Union[str, bool]],
        drop_singletons: bool,
        drop_intercept: bool,
        weights: Optional[str],
        weights_type: Optional[str],
        collin_tol: float,
        fixef_tol: float,
        lookup_demeaned_data: dict[str, pd.DataFrame],
        solver: Literal[
            "np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax"
        ] = "np.linalg.solve",
        demeaner_backend: Literal["numba", "jax"] = "numba",
        store_data: bool = True,
        copy_data: bool = True,
        lean: bool = False,
        context: Union[int, Mapping[str, Any]] = 0,
        sample_split_var: Optional[str] = None,
        sample_split_value: Optional[Union[str, int]] = None,
    ) -> None:
        """Initialize via Feols, build jax arrays, and set up the OLSJAX model."""
        super().__init__(
            FixestFormula=FixestFormula,
            data=data,
            ssc_dict=ssc_dict,
            drop_singletons=drop_singletons,
            drop_intercept=drop_intercept,
            weights=weights,
            weights_type=weights_type,
            collin_tol=collin_tol,
            fixef_tol=fixef_tol,
            lookup_demeaned_data=lookup_demeaned_data,
            solver=solver,
            store_data=store_data,
            copy_data=copy_data,
            lean=lean,
            sample_split_var=sample_split_var,
            sample_split_value=sample_split_value,
            context=context,
            demeaner_backend=demeaner_backend,
        )

        self.prepare_model_matrix()
        self.to_jax_array()

        # later to be set in multicoll method
        self._N, self._k = self._X_jax.shape

        self.olsjax = OLSJAX(
            X=self._X_jax,
            Y=self._Y_jax,
            fe=self._fe_jax,
            weights=self._weights_jax,
            vcov="iid",
        )
        # Pre-demean once here so that subsequent get_fit() / vcov() calls
        # operate on the demeaned data.
        self.olsjax.Y, self.olsjax.X = self.olsjax.demean(
            Y=self._Y_jax,
            X=self._X_jax,
            fe=self._fe_jax,
            weights=self._weights_jax.flatten(),
        )

    def to_jax_array(self):
        """Convert the model matrices produced by Feols to jax arrays."""
        self._X_jax = jnp.array(self._X)
        self._Y_jax = jnp.array(self._Y)
        self._fe_jax = jnp.array(self._fe)
        self._weights_jax = jnp.array(self._weights)

    def get_fit(self):
        """Run the OLSJAX fit and store coefficients, residuals, and scores."""
        self.olsjax.get_fit()
        self._beta_hat = self.olsjax.beta.flatten()
        self._u_hat = self.olsjax.residuals
        self._scores = self.olsjax.scores

    def vcov(self, type: str):
        """
        Compute the variance-covariance matrix of the estimated coefficients.

        Parameters
        ----------
        type : str
            Covariance type, forwarded to OLSJAX.vcov
            ("iid", "HC1", "HC2", "HC3", or "CRV1").

        Returns
        -------
        OLSJAX_API
            The instance itself, for chaining.
        """
        self._vcov_type = type
        self.olsjax.vcov(vcov_type=type)
        # OLSJAX.vcov stores its result on the model's `vcov` attribute
        self._vcov = self.olsjax.vcov
        return self

    def convert_attributes_to_numpy(self):
        """Convert core attributes from jax to numpy arrays."""
        attr = ["_beta_hat", "_u_hat", "_scores", "_vcov"]
        for a in attr:
            # convert to numpy
            setattr(self, a, np.array(getattr(self, a)))
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstrings?