diff --git a/pyfixest/estimation/ritest.py b/pyfixest/estimation/ritest.py index abf16da75..d92265975 100644 --- a/pyfixest/estimation/ritest.py +++ b/pyfixest/estimation/ritest.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import seaborn as sns +from joblib import Parallel, delayed from lets_plot import ( LetsPlot, aes, @@ -24,7 +25,6 @@ LetsPlot.setup_html() - def _get_ritest_stats_slow( data: pd.DataFrame, resampvar: str, @@ -77,28 +77,36 @@ def _get_ritest_stats_slow( fit_ = getattr(fixest_module, model) resampvar_arr = data_resampled[resampvar].to_numpy() - ri_stats = np.zeros(reps) - for i in tqdm(range(reps)): - D_treat = _resample( - resampvar_arr=resampvar_arr, - clustervar_arr=clustervar_arr, - rng=rng, - iterations=1, - ).flatten() - - data_resampled[f"{resampvar}_resampled"] = D_treat + results = Parallel(n_jobs=-1)( + delayed(lambda: ( + # Create resampled treatment values + D_treat := _resample( + resampvar_arr=resampvar_arr, + clustervar_arr=clustervar_arr, + rng=rng, + iterations=1, + ).flatten(), + + # Add values to data + data_resampled.__setitem__(f"{resampvar}_resampled", D_treat), + fixest_fit := fit_(fml_update, data=data_resampled, vcov=vcov), + + # Return appropriate statistic + fixest_fit.coef().xs(f"{resampvar}_resampled") + if type == "randomization-c" + else fixest_fit.tstat().xs(f"{resampvar}_resampled") + )[3])() + for _ in tqdm(range(reps)) # We use _ since we don't actually need the index + ) - fixest_fit = fit_(fml_update, data=data_resampled, vcov=vcov) - if type == "randomization-c": - ri_stats[i] = fixest_fit.coef().xs(f"{resampvar}_resampled") - else: - ri_stats[i] = fixest_fit.tstat().xs(f"{resampvar}_resampled") + # Fill out the results array + for i, result in enumerate(results): + ri_stats[i] = result return ri_stats - def _get_ritest_stats_fast( Y: np.ndarray, X: np.ndarray,