From b2f5b6300d1fcff74f995ba91971984469ab7ded Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Thu, 9 Jan 2025 23:16:21 +0100 Subject: [PATCH 1/6] add jax benchmark notebook --- benchmarks/gpu_pyfixest_errors.ipynb | 1478 ++++++++++++++++++++++++++ 1 file changed, 1478 insertions(+) create mode 100644 benchmarks/gpu_pyfixest_errors.ipynb diff --git a/benchmarks/gpu_pyfixest_errors.ipynb b/benchmarks/gpu_pyfixest_errors.ipynb new file mode 100644 index 000000000..3f6ee1481 --- /dev/null +++ b/benchmarks/gpu_pyfixest_errors.ipynb @@ -0,0 +1,1478 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:23:28.917621Z", + "iopub.status.busy": "2025-01-09T00:23:28.917332Z", + "iopub.status.idle": "2025-01-09T00:23:29.477701Z", + "shell.execute_reply": "2025-01-09T00:23:29.477193Z", + "shell.execute_reply.started": "2025-01-09T00:23:28.917602Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[CpuDevice(id=0)]" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "jax.devices()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:23:30.540594Z", + "iopub.status.busy": "2025-01-09T00:23:30.540228Z", + "iopub.status.idle": "2025-01-09T00:23:30.739685Z", + "shell.execute_reply": "2025-01-09T00:23:30.739213Z", + "shell.execute_reply.started": "2025-01-09T00:23:30.540574Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{CpuDevice(id=0)}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.ones(10).devices()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:26:29.239253Z", + "iopub.status.busy": 
"2025-01-09T00:26:29.238947Z", + "iopub.status.idle": "2025-01-09T00:26:29.754752Z", + "shell.execute_reply": "2025-01-09T00:26:29.754158Z", + "shell.execute_reply.started": "2025-01-09T00:26:29.239235Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'nvidia-smi' is not recognized as an internal or external command,\n", + "operable program or batch file.\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:23:44.232335Z", + "iopub.status.busy": "2025-01-09T00:23:44.231984Z", + "iopub.status.idle": "2025-01-09T00:23:45.388035Z", + "shell.execute_reply": "2025-01-09T00:23:45.387587Z", + "shell.execute_reply.started": "2025-01-09T00:23:44.232312Z" + }, + "id": "fHzEldNvR2_K" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import time\n", + "from itertools import product\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from scipy.stats import nbinom\n", + "from tqdm import tqdm\n", + "\n", + "import pyfixest as pf\n", + "from pyfixest.estimation.demean_ import demean\n", + "from pyfixest.estimation.demean_jax_ import demean_jax" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2025-01-09T00:23:46.290548Z", + "iopub.status.busy": "2025-01-09T00:23:46.289898Z", + "iopub.status.idle": "2025-01-09T00:23:46.417097Z", + "shell.execute_reply": "2025-01-09T00:23:46.416504Z", + "shell.execute_reply.started": "2025-01-09T00:23:46.290525Z" + }, + "id": "XQjP2889YJxs", + "outputId": "3e686d7b-0774-4bb5-c1b9-28e5b9f286a9" + }, + "outputs": [], + "source": [ + "# %load_ext watermark\n", + "# %watermark --iversions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "background_save": true + }, + "execution": { + "iopub.execute_input": "2025-01-09T00:23:49.545271Z", + "iopub.status.busy": "2025-01-09T00:23:49.545016Z", + "iopub.status.idle": "2025-01-09T00:23:49.552123Z", + "shell.execute_reply": "2025-01-09T00:23:49.551676Z", + "shell.execute_reply.started": "2025-01-09T00:23:49.545253Z" + }, + "id": "bxMmeyCxR3fb" + }, + "outputs": [], + "source": [ + "def generate_test_data(size: int, k: int = 2):\n", + " \"\"\"\n", + " Generate benchmark data for pyfixest on GPU (similar to the R fixest benchmark data).\n", + "\n", + " Args:\n", + " size (int): Size index between 1 and 5; the data frame has 1000 * 10**(size - 1) observations.\n", + " k (int): The number of covariates in the data frame.\n", + "\n", + " Returns\n", + " -------\n", + " tuple: The data frame, outcome y, covariate matrix, fixed-effect ids and weights generated for 
the given size.\n", + " \"\"\"\n", + " # Constants\n", + " all_n = [1000 * 10**i for i in range(5)]\n", + " a = 1\n", + " b = 0.05\n", + "\n", + " n = all_n[size - 1]\n", + "\n", + " dum_all = []\n", + " nb_dum = [n // 20, int(np.sqrt(n)), int(n**0.33)]\n", + "\n", + " dum_all = np.zeros((n, 3))\n", + " dum_all[:, 0] = np.random.choice(nb_dum[0], n, replace=True)\n", + " dum_all[:, 1] = np.random.choice(nb_dum[1], n, replace=True)\n", + " dum_all[:, 2] = np.random.choice(nb_dum[2], n, replace=True)\n", + " dum_all = dum_all.astype(int)\n", + "\n", + " X1 = np.random.normal(size=n)\n", + " X2 = X1**2\n", + "\n", + " mu = a * X1 + b * X2\n", + "\n", + " for m in range(3):\n", + " coef_dum = np.random.normal(size=nb_dum[m])\n", + " mu += coef_dum[dum_all[:, m]]\n", + "\n", + " mu = np.exp(mu)\n", + " y = nbinom.rvs(0.5, 1 - (mu / (mu + 0.5)), size=n)\n", + "\n", + " X_full = np.column_stack((X1, X2))\n", + " base = pd.DataFrame(\n", + " {\n", + " \"y\": y,\n", + " \"ln_y\": np.log(y + 1),\n", + " \"X1\": X1,\n", + " \"X2\": X2,\n", + " }\n", + " )\n", + "\n", + " if k > 2:\n", + " X = np.random.normal(size=(n, k - 2))\n", + " X_df = pd.DataFrame(X, columns=[f\"X{i}\" for i in range(3, k + 1, 1)])\n", + " base = pd.concat([base, X_df], axis=1)\n", + " X_full = np.column_stack((X_full, X))\n", + "\n", + " for m in range(3):\n", + " base[f\"dum_{m + 1}\"] = dum_all[:, m]\n", + "\n", + " weights = np.random.uniform(0, 1, n)\n", + " return base, y, X_full, dum_all, weights" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:23:50.285297Z", + "iopub.status.busy": "2025-01-09T00:23:50.284967Z", + "iopub.status.idle": "2025-01-09T00:23:50.460957Z", + "shell.execute_reply": "2025-01-09T00:23:50.460501Z", + "shell.execute_reply.started": "2025-01-09T00:23:50.285276Z" + }, + "id": "nzynhbqwR81H" + }, + "outputs": [], + "source": [ + "df, Y, X, f, weights = generate_test_data(1)" + ] + }, + { + 
"cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:25:02.873750Z", + "iopub.status.busy": "2025-01-09T00:25:02.873239Z", + "iopub.status.idle": "2025-01-09T00:25:03.153458Z", + "shell.execute_reply": "2025-01-09T00:25:03.153005Z", + "shell.execute_reply.started": "2025-01-09T00:25:02.873732Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "###\n", + "\n", + "Estimation: OLS\n", + "Dep. var.: ln_y, Fixed effects: dum_1\n", + "Inference: CRV1\n", + "Observations: 1000\n", + "\n", + "| Coefficient | Estimate | Std. Error | t value | Pr(>|t|) | 2.5% | 97.5% |\n", + "|:--------------|-----------:|-------------:|----------:|-----------:|-------:|--------:|\n", + "| X1 | 0.436 | 0.046 | 9.440 | 0.000 | 0.343 | 0.529 |\n", + "---\n", + "RMSE: 1.067 R2: 0.242 R2 Within: 0.131 \n" + ] + } + ], + "source": [ + "m0 = pf.feols(\"ln_y ~ X1 | dum_1\", df, demeaner_backend=\"numba\")\n", + "m0.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:25:03.330646Z", + "iopub.status.busy": "2025-01-09T00:25:03.330367Z", + "iopub.status.idle": "2025-01-09T00:25:03.571916Z", + "shell.execute_reply": "2025-01-09T00:25:03.571482Z", + "shell.execute_reply.started": "2025-01-09T00:25:03.330629Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "###\n", + "\n", + "Estimation: OLS\n", + "Dep. var.: ln_y, Fixed effects: dum_1\n", + "Inference: CRV1\n", + "Observations: 1000\n", + "\n", + "| Coefficient | Estimate | Std. 
Error | t value | Pr(>|t|) | 2.5% | 97.5% |\n", + "|:--------------|-----------:|-------------:|----------:|-----------:|-------:|--------:|\n", + "| X1 | 0.436 | 0.046 | 9.440 | 0.000 | 0.343 | 0.529 |\n", + "---\n", + "RMSE: 1.067 R2: 0.242 R2 Within: 0.131 \n" + ] + } + ], + "source": [ + "m1 = pf.feols(\"ln_y ~ X1 | dum_1\", df, demeaner_backend=\"jax\")\n", + "m1.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## function" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:24:06.619552Z", + "iopub.status.busy": "2025-01-09T00:24:06.619273Z", + "iopub.status.idle": "2025-01-09T00:24:06.626298Z", + "shell.execute_reply": "2025-01-09T00:24:06.625727Z", + "shell.execute_reply.started": "2025-01-09T00:24:06.619534Z" + }, + "id": "29rZkULUR_A0" + }, + "outputs": [], + "source": [ + "def run_standard_benchmark(\n", + " fixed_effect,\n", + " demeaner_backend,\n", + " size=1,\n", + " k=1,\n", + " solver=\"np.linalg.lstsq\",\n", + " skip_demean_benchmark=True,\n", + "):\n", + " \"\"\"\n", + " Run the fixest standard benchmark fixed effect models. This is the function the benchmarks\n", + " will loop over.\n", + "\n", + " Args:\n", + " fixed_effect (str): The fixed effect to use. Must be a list of variables as \"dum_1\", \"dum_1+dum_2\", or \"dum_1+dum_2+dum_3\", etc.\n", + " demeaner_backend (str): The backend to use for demeaning. Must be \"numba\" or \"jax\".\n", + " size (int): The size of the data to generate. Must be between 1 and 5. For 1, N = 1000, for 2, N = 10000, etc.\n", + " k (int): The number of covariates to generate.\n", + " solver (str): The solver to use for the estimation. Must be \"np.linalg.lstsq\". \"jax\" currently throws an error.\n", + " skip_demean_benchmark (bool): Whether to skip the \"pure\" demean benchmark. Default is True. 
Only the full call\n", + " to feols is benchmarked.\n", + "\n", + " \"\"\"\n", + " assert fixed_effect in [\"dum_1\", \"dum_1+dum_2\", \"dum_1+dum_2+dum_3\"]\n", + "\n", + " # one fixed effect\n", + " res = []\n", + "\n", + " fml_base = \"ln_y ~ X1\"\n", + " fml = f\"{fml_base} | {fixed_effect}\"\n", + "\n", + " # warmup\n", + " df, y, X, f, weights = generate_test_data(1)\n", + " pf.feols(\n", + " fml,\n", + " data=df,\n", + " demeaner_backend=demeaner_backend,\n", + " store_data=False,\n", + " copy_data=False,\n", + " solver=solver,\n", + " )\n", + "\n", + " if k > 1:\n", + " xfml = \"+\".join([f\"X{i}\" for i in range(2, k + 1, 1)])\n", + " fml = f\"{fml_base} + {xfml} | {fixed_effect}\"\n", + " else:\n", + " fml = f\"{fml_base} + X1 | {fixed_effect}\"\n", + "\n", + " for rep in range(1, 11):\n", + " df, Y, X, f, weights = generate_test_data(size=size, k=k)\n", + "\n", + " tic1 = time.time()\n", + " pf.feols(\n", + " fml,\n", + " data=df,\n", + " demeaner_backend=demeaner_backend,\n", + " store_data=False,\n", + " copy_data=False,\n", + " solver=solver,\n", + " )\n", + " tic2 = time.time()\n", + "\n", + " full_feols_timing = tic2 - tic1\n", + "\n", + " demean_timing = np.nan\n", + " if not skip_demean_benchmark:\n", + " YX = np.column_stack((Y.reshape(-1, 1), X))\n", + " tic3 = time.time()\n", + " if demeaner_backend == \"jax\":\n", + " _, _ = demean_jax(YX, f, weights, tol=1e-10)\n", + " else:\n", + " _, _ = demean(YX, f, weights, tol=1e-10)\n", + " tic4 = time.time()\n", + " demean_timing = tic4 - tic3\n", + "\n", + " res.append(\n", + " pd.Series(\n", + " {\n", + " \"method\": \"feols\",\n", + " \"solver\": solver,\n", + " \"demeaner_backend\": demeaner_backend,\n", + " \"n_obs\": df.shape[0],\n", + " \"k\": k,\n", + " \"G\": len(fixed_effect.split(\"+\")),\n", + " \"rep\": rep,\n", + " \"full_feols_timing\": full_feols_timing,\n", + " \"demean_timing\": demean_timing,\n", + " }\n", + " )\n", + " )\n", + "\n", + " return pd.concat(res, axis=1).T" + ] + }, + { 
+ "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:28:43.818536Z", + "iopub.status.busy": "2025-01-09T00:28:43.818246Z", + "iopub.status.idle": "2025-01-09T00:28:51.489202Z", + "shell.execute_reply": "2025-01-09T00:28:51.488591Z", + "shell.execute_reply.started": "2025-01-09T00:28:43.818520Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
methodsolverdemeaner_backendn_obskGrepfull_feols_timingdemean_timing
0feolsnp.linalg.lstsqnumba10001110.150473NaN
1feolsnp.linalg.lstsqnumba10001120.147583NaN
2feolsnp.linalg.lstsqnumba10001130.186491NaN
3feolsnp.linalg.lstsqnumba10001140.190972NaN
4feolsnp.linalg.lstsqnumba10001150.162773NaN
5feolsnp.linalg.lstsqnumba10001160.171777NaN
6feolsnp.linalg.lstsqnumba10001170.166872NaN
7feolsnp.linalg.lstsqnumba10001180.158694NaN
8feolsnp.linalg.lstsqnumba10001190.185547NaN
9feolsnp.linalg.lstsqnumba100011100.158114NaN
\n", + "
" + ], + "text/plain": [ + " method solver demeaner_backend n_obs k G rep full_feols_timing \\\n", + "0 feols np.linalg.lstsq numba 1000 1 1 1 0.150473 \n", + "1 feols np.linalg.lstsq numba 1000 1 1 2 0.147583 \n", + "2 feols np.linalg.lstsq numba 1000 1 1 3 0.186491 \n", + "3 feols np.linalg.lstsq numba 1000 1 1 4 0.190972 \n", + "4 feols np.linalg.lstsq numba 1000 1 1 5 0.162773 \n", + "5 feols np.linalg.lstsq numba 1000 1 1 6 0.171777 \n", + "6 feols np.linalg.lstsq numba 1000 1 1 7 0.166872 \n", + "7 feols np.linalg.lstsq numba 1000 1 1 8 0.158694 \n", + "8 feols np.linalg.lstsq numba 1000 1 1 9 0.185547 \n", + "9 feols np.linalg.lstsq numba 1000 1 1 10 0.158114 \n", + "\n", + " demean_timing \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN \n", + "6 NaN \n", + "7 NaN \n", + "8 NaN \n", + "9 NaN " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# test run numba\n", + "run_standard_benchmark(fixed_effect=\"dum_1\", demeaner_backend=\"numba\", size=1, k=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:28:43.818536Z", + "iopub.status.busy": "2025-01-09T00:28:43.818246Z", + "iopub.status.idle": "2025-01-09T00:28:51.489202Z", + "shell.execute_reply": "2025-01-09T00:28:51.488591Z", + "shell.execute_reply.started": "2025-01-09T00:28:43.818520Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
methodsolverdemeaner_backendn_obskGrepfull_feols_timingdemean_timing
0feolsnp.linalg.lstsqjax10001110.122831NaN
1feolsnp.linalg.lstsqjax10001120.122887NaN
2feolsnp.linalg.lstsqjax10001130.136041NaN
3feolsnp.linalg.lstsqjax10001140.139644NaN
4feolsnp.linalg.lstsqjax10001150.136235NaN
5feolsnp.linalg.lstsqjax10001160.122477NaN
6feolsnp.linalg.lstsqjax10001170.123122NaN
7feolsnp.linalg.lstsqjax10001180.119589NaN
8feolsnp.linalg.lstsqjax10001190.122247NaN
9feolsnp.linalg.lstsqjax100011100.118353NaN
\n", + "
" + ], + "text/plain": [ + " method solver demeaner_backend n_obs k G rep full_feols_timing \\\n", + "0 feols np.linalg.lstsq jax 1000 1 1 1 0.122831 \n", + "1 feols np.linalg.lstsq jax 1000 1 1 2 0.122887 \n", + "2 feols np.linalg.lstsq jax 1000 1 1 3 0.136041 \n", + "3 feols np.linalg.lstsq jax 1000 1 1 4 0.139644 \n", + "4 feols np.linalg.lstsq jax 1000 1 1 5 0.136235 \n", + "5 feols np.linalg.lstsq jax 1000 1 1 6 0.122477 \n", + "6 feols np.linalg.lstsq jax 1000 1 1 7 0.123122 \n", + "7 feols np.linalg.lstsq jax 1000 1 1 8 0.119589 \n", + "8 feols np.linalg.lstsq jax 1000 1 1 9 0.122247 \n", + "9 feols np.linalg.lstsq jax 1000 1 1 10 0.118353 \n", + "\n", + " demean_timing \n", + "0 NaN \n", + "1 NaN \n", + "2 NaN \n", + "3 NaN \n", + "4 NaN \n", + "5 NaN \n", + "6 NaN \n", + "7 NaN \n", + "8 NaN \n", + "9 NaN " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# test run jax\n", + "run_standard_benchmark(fixed_effect=\"dum_1\", demeaner_backend=\"jax\", size=1, k=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2025-01-09T00:25:39.248695Z", + "iopub.status.busy": "2025-01-09T00:25:39.248317Z", + "iopub.status.idle": "2025-01-09T00:25:39.253652Z", + "shell.execute_reply": "2025-01-09T00:25:39.253000Z", + "shell.execute_reply.started": "2025-01-09T00:25:39.248671Z" + }, + "id": "HxsGRMSlR_jI" + }, + "outputs": [], + "source": [ + "def run_all_benchmarks(size_list, k_list):\n", + " \"\"\"\n", + " Run all the benchmarks.\n", + "\n", + " Args:\n", + " size_list (list): The list of sizes to run the benchmarks on. 
1-> 1000, 2-> 10000, ..., 5-> 10_000_000\n", + " k_list (list): The list of k values to run the benchmarks on.\n", + " \"\"\"\n", + " res = pd.DataFrame()\n", + "\n", + " all_combinations = list(\n", + " product(\n", + " [\"numba\", \"jax\"], # demeaner_backend\n", + " [\"dum_1\", \"dum_1+dum_2\", \"dum_1+dum_2+dum_3\"], # fixef\n", + " size_list, # size\n", + " k_list, # k\n", + " [\"np.linalg.lstsq\"], # solver\n", + " )\n", + " )\n", + "\n", + " with tqdm(total=len(all_combinations), desc=\"Running Benchmarks\") as pbar:\n", + " for demeaner_backend, fixef, size, k, solver in all_combinations:\n", + " res = pd.concat(\n", + " [\n", + " res,\n", + " run_standard_benchmark(\n", + " solver=solver,\n", + " fixed_effect=fixef,\n", + " demeaner_backend=demeaner_backend,\n", + " size=size,\n", + " k=k,\n", + " ),\n", + " ],\n", + " axis=0,\n", + " )\n", + " pbar.update(1) # Update the progress bar after each iteration\n", + "\n", + " return res" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run Benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2025-01-09T00:25:39.962554Z", + "iopub.status.busy": "2025-01-09T00:25:39.962039Z", + "iopub.status.idle": "2025-01-09T00:26:25.319310Z", + "shell.execute_reply": "2025-01-09T00:26:25.318687Z", + "shell.execute_reply.started": "2025-01-09T00:25:39.962536Z" + }, + "id": "gki1mlqvSEIi", + "outputId": "3cb40095-df81-4e78-99a6-2410da237884" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Running Benchmarks: 100%|██████████| 24/24 [00:49<00:00, 2.06s/it]\n" + ] + } + ], + "source": [ + "res_all = run_all_benchmarks(\n", + " size_list=[1, 2, 3, 4, 5], # for N = 1000, 10_000, 100_000, 1_000_000, 10_000_000\n", + " k_list=[1, 10, 50, 100], # for k = 1, 10, 50, 100\n", + ")" + ] + }, + { + "cell_type": 
"code", + "execution_count": 15, + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2025-01-09T00:26:25.320894Z", + "iopub.status.busy": "2025-01-09T00:26:25.320596Z", + "iopub.status.idle": "2025-01-09T00:26:26.391854Z", + "shell.execute_reply": "2025-01-09T00:26:26.391236Z", + "shell.execute_reply.started": "2025-01-09T00:26:25.320871Z" + }, + "id": "7zEIHj5nXXvq", + "outputId": "238dedce-a9f5-4a1b-b9a2-b84d4a89c091" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
methoddemeaner_backendkGn_obsfull_feols_timingdemean_timing
0feolsjax1110000.127274NaN
1feolsjax11100000.142241NaN
2feolsjax1210000.134405NaN
3feolsjax12100000.168555NaN
4feolsjax1310000.139142NaN
5feolsjax13100000.189492NaN
6feolsjax5110000.144933NaN
7feolsjax51100000.147049NaN
8feolsjax5210000.147216NaN
9feolsjax52100000.175394NaN
10feolsjax5310000.15034NaN
11feolsjax53100000.194496NaN
12feolsnumba1110000.153849NaN
13feolsnumba11100000.183697NaN
14feolsnumba1210000.17201NaN
15feolsnumba12100000.16131NaN
16feolsnumba1310000.169463NaN
17feolsnumba13100000.169789NaN
18feolsnumba5110000.331457NaN
19feolsnumba51100000.172194NaN
20feolsnumba5210000.162562NaN
21feolsnumba52100000.184233NaN
22feolsnumba5310000.174623NaN
23feolsnumba53100000.166674NaN
\n", + "
" + ], + "text/plain": [ + " method demeaner_backend k G n_obs full_feols_timing demean_timing\n", + "0 feols jax 1 1 1000 0.127274 NaN\n", + "1 feols jax 1 1 10000 0.142241 NaN\n", + "2 feols jax 1 2 1000 0.134405 NaN\n", + "3 feols jax 1 2 10000 0.168555 NaN\n", + "4 feols jax 1 3 1000 0.139142 NaN\n", + "5 feols jax 1 3 10000 0.189492 NaN\n", + "6 feols jax 5 1 1000 0.144933 NaN\n", + "7 feols jax 5 1 10000 0.147049 NaN\n", + "8 feols jax 5 2 1000 0.147216 NaN\n", + "9 feols jax 5 2 10000 0.175394 NaN\n", + "10 feols jax 5 3 1000 0.15034 NaN\n", + "11 feols jax 5 3 10000 0.194496 NaN\n", + "12 feols numba 1 1 1000 0.153849 NaN\n", + "13 feols numba 1 1 10000 0.183697 NaN\n", + "14 feols numba 1 2 1000 0.17201 NaN\n", + "15 feols numba 1 2 10000 0.16131 NaN\n", + "16 feols numba 1 3 1000 0.169463 NaN\n", + "17 feols numba 1 3 10000 0.169789 NaN\n", + "18 feols numba 5 1 1000 0.331457 NaN\n", + "19 feols numba 5 1 10000 0.172194 NaN\n", + "20 feols numba 5 2 1000 0.162562 NaN\n", + "21 feols numba 5 2 10000 0.184233 NaN\n", + "22 feols numba 5 3 1000 0.174623 NaN\n", + "23 feols numba 5 3 10000 0.166674 NaN" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = (\n", + " res_all.drop([\"rep\", \"solver\"], axis=1)\n", + " .groupby([\"method\", \"demeaner_backend\", \"k\", \"G\", \"n_obs\"])\n", + " .mean()\n", + " .reset_index()\n", + ")\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab": { + "background_save": true + }, + "id": "VCn6O5MMXlBw" + }, + "source": [ + "## Visualize" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. 
If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. 
If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", + "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df[\"G\"] = df[\"G\"].map({1: \"n_fixef = 1\", 2: \"n_fixef = 2\", 3: \"n_fixef = 3\"})\n", + "df[\"n_obs\"] = df[\"n_obs\"].astype(str)\n", + "\n", + "# Dynamically determine unique values for order and hue_order\n", + "n_obs_order = sorted(df[\"n_obs\"].unique(), key=lambda x: int(x)) # Sort as integers\n", + "demeaner_backend_order = df[\"demeaner_backend\"].unique()\n", + "\n", + "custom_palette = sns.color_palette(\"coolwarm\", n_colors=2)\n", + "\n", + "# Create the FacetGrid with reordered columns and rows\n", + "g = sns.FacetGrid(\n", + " df,\n", + " col=\"G\", # G (n_fixef) increases left to right\n", + " row=\"k\", # k increases top to bottom\n", + " margin_titles=True,\n", + " height=4,\n", + " aspect=1.2,\n", + " col_order=[\"n_fixef = 1\", \"n_fixef = 2\", \"n_fixef = 3\"], # Ensure correct order\n", + ")\n", + "\n", + "# Plot the bar chart for each facet with the custom palette\n", + "g.map(\n", + " sns.barplot,\n", + " \"n_obs\",\n", + " \"full_feols_timing\",\n", + " \"demeaner_backend\",\n", + " order=n_obs_order, # Dynamic order for n_obs\n", + " hue_order=demeaner_backend_order, # Dynamic hue order for demeaner_backend\n", + " errorbar=None, # Suppress error bars\n", + " palette=custom_palette,\n", + ")\n", + "\n", + "# Add legend and adjust layout\n", + "g.add_legend(title=\"Demeaner Backend\")\n", + "g.set_axis_labels(\"Number of Observations\", \"Runtime (seconds)\")\n", + "g.set_titles(row_template=\"k = {row_name}\", col_template=\"{col_name}\")\n", + "plt.subplots_adjust(top=0.9)\n", + "g.fig.suptitle(\"Runtime vs Number of Observations by n_fixef and k\")\n", + "\n", + "# Show plot\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "jax", + "language": "python", + "name": 
"python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From ffd3217020c6fe6ee4fb4eb0831fdfbbf9dfd597 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Tue, 14 Jan 2025 21:43:54 +0100 Subject: [PATCH 2/6] add jax module, JAXOLS --- benchmarks/gpu_pyfixest_errors.ipynb | 2 +- pyfixest/estimation/demean_.py | 2 +- pyfixest/estimation/jax/OLSJAX.py | 145 ++++++++++++++++++ pyfixest/estimation/{ => jax}/demean_jax_.py | 9 +- .../{ => jax}/detect_singletons_jax.py | 0 tests/test_demean.py | 2 +- tests/test_detect_singletons.py | 2 +- 7 files changed, 157 insertions(+), 5 deletions(-) create mode 100644 pyfixest/estimation/jax/OLSJAX.py rename pyfixest/estimation/{ => jax}/demean_jax_.py (92%) rename pyfixest/estimation/{ => jax}/detect_singletons_jax.py (100%) diff --git a/benchmarks/gpu_pyfixest_errors.ipynb b/benchmarks/gpu_pyfixest_errors.ipynb index 3f6ee1481..945569cd9 100644 --- a/benchmarks/gpu_pyfixest_errors.ipynb +++ b/benchmarks/gpu_pyfixest_errors.ipynb @@ -194,7 +194,7 @@ "\n", "import pyfixest as pf\n", "from pyfixest.estimation.demean_ import demean\n", - "from pyfixest.estimation.demean_jax_ import demean_jax" + "from pyfixest.estimation.jax.demean_jax_ import demean_jax" ] }, { diff --git a/pyfixest/estimation/demean_.py b/pyfixest/estimation/demean_.py index 5987dde43..ecff0b5aa 100644 --- a/pyfixest/estimation/demean_.py +++ b/pyfixest/estimation/demean_.py @@ -329,7 +329,7 @@ def _set_demeaner_backend(demeaner_backend: Literal["numba", "jax"]) -> Callable if demeaner_backend == "numba": return demean elif demeaner_backend == "jax": - from pyfixest.estimation.demean_jax_ import demean_jax + from pyfixest.estimation.jax.demean_jax_ import demean_jax return demean_jax else: diff --git 
a/pyfixest/estimation/jax/OLSJAX.py b/pyfixest/estimation/jax/OLSJAX.py new file mode 100644 index 000000000..bde7e29df --- /dev/null +++ b/pyfixest/estimation/jax/OLSJAX.py @@ -0,0 +1,145 @@ +import jax +import jax.numpy as jnp +from pyfixest.estimation.jax.demean_jax_ import demean_jax +from typing import Optional +import pandas as pd + +class OLSJAX: + + def __init__(self, X: jax.Array, Y: jax.Array, fe: Optional[jax.Array] = None, vcov: str = 'iid'): + + self.X_orignal = X + self.Y_orignal = Y + self.fe = fe + self.N = X.shape[0] + self.k = X.shape[1] + self.weights = jnp.ones(self.N) + self.vcov_type = vcov + + def fit(self): + + self.Y, self.X = self.demean(Y = self.Y_orignal, X = self.X_orignal, fe = self.fe, weights = self.weights) + self.beta = jnp.linalg.lstsq(self.X, self.Y)[0] + self.residuals + self.scores + self.vcov(vcov_type = self.vcov_type) + self.inference() + + def tidy(self): + + lb = 0.025 + ub = 0.975 + + return pd.DataFrame( + { + #"Coefficient": _coefnames, + "Estimate": self.beta.flatten(), + "Std. 
Error": self.se.flatten(), + "t value": self.tstat.flatten(), + "Pr(>|t|)": self.pvalue.flatten(), + f"{lb * 100:.1f}%": self.confint[:,0].flatten(), + f"{ub * 100:.1f}%": self.confint[:,1].flatten(), + } + ) + + @property + def residuals(self): + self.uhat = self.Y - self.X @ self.beta + return self.uhat + + def vcov(self, vcov_type: str): + + bread = self.bread + meat = self.meat(type = vcov_type) + if vcov_type == 'iid': + self.vcov = bread * meat + else: + self.vcov = bread @ meat @ bread + + return self.vcov + + @property + def bread(self): + return jnp.linalg.inv(self.X.T @ self.X) + + @property + def leverage(self): + return jnp.sum(self.X * (self.X @ jnp.linalg.inv(self.X.T @ self.X)), axis=1) + + @property + def scores(self): + return self.X * self.residuals + + def meat(self, type: str): + + if type == 'iid': + return self.meat_iid + elif type == 'HC1': + return self.meat_hc1 + elif type == 'HC2': + return self.meat_hc2 + elif type == 'HC3': + return self.meat_hc3 + elif type == 'CRV1': + return self.meat_crv1 + else: + raise ValueError("Invalid type") + + @property + def meat_iid(self): + return jnp.sum(self.uhat ** 2) / (self.N - self.k) + + @property + def meat_hc1(self): + + return self.scores.T @ self.scores + + def meat_hc2(self): + self.leverage + transformed_scores = self.scores / jnp.sqrt(1 - self.leverage) + return transformed_scores.T @ transformed_scores + + def meat_hc3(self): + self.leverage + transformed_scores = self.scores / (1 - self.leverage) + return transformed_scores.T @ transformed_scores + + @property + def meat_crv1(self): + raise NotImplementedError("CRV1 is not implemented") + + def predict(self, X): + X = jnp.array(X) + return X @ self.beta + + def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array): + + if fe is not None: + if not jnp.issubdtype(fe.dtype, jnp.integer): + raise ValueError("Fixed effects must be integers") + + YX = jnp.concatenate((Y, X), axis = 1) + YXd, success = demean_jax(x = YX, 
flist = fe, weights = self.weights, output = "jax") + Yd = YXd[:, 0].reshape(-1, 1) + Xd = YXd[:, 1:] + + return Yd, Xd + + else: + + return Y, X + + def inference(self): + + self.se = jnp.sqrt(jnp.diag(self.vcov)).reshape(-1, 1) + self.tstat = self.beta / self.se + self.pvalue = 2 * (1 - jax.scipy.stats.norm.cdf(jnp.abs(self.tstat))) + self.confint = jnp.column_stack( + [ + self.beta - jax.scipy.stats.norm.ppf(1 - 0.05 / 2) * self.se, + self.beta + jax.scipy.stats.norm.ppf(1 - 0.05 / 2) * self.se, + ] + ) + + + return self.se, self.tstat, self.pvalue, self.confint diff --git a/pyfixest/estimation/demean_jax_.py b/pyfixest/estimation/jax/demean_jax_.py similarity index 92% rename from pyfixest/estimation/demean_jax_.py rename to pyfixest/estimation/jax/demean_jax_.py index 2d327fb2d..0c27ec0f3 100644 --- a/pyfixest/estimation/demean_jax_.py +++ b/pyfixest/estimation/jax/demean_jax_.py @@ -72,6 +72,7 @@ def demean_jax( weights: np.ndarray, tol: float = 1e-08, maxiter: int = 100_000, + output: str = "numpy", ) -> tuple[np.ndarray, bool]: """Fast and reliable JAX implementation with static shapes.""" # Enable float64 precision @@ -89,4 +90,10 @@ def demean_jax( result_jax, converged = _demean_jax_impl( x_jax, flist_jax, weights_jax, n_groups, tol, maxiter ) - return np.array(result_jax), converged + + if output == "numpy": + return np.array(result_jax), converged + elif output == "jax": + return result_jax, converged + else: + raise ValueError("Invalid output type") diff --git a/pyfixest/estimation/detect_singletons_jax.py b/pyfixest/estimation/jax/detect_singletons_jax.py similarity index 100% rename from pyfixest/estimation/detect_singletons_jax.py rename to pyfixest/estimation/jax/detect_singletons_jax.py diff --git a/tests/test_demean.py b/tests/test_demean.py index 973e5958e..a4c08aeed 100644 --- a/tests/test_demean.py +++ b/tests/test_demean.py @@ -4,7 +4,7 @@ import pytest from pyfixest.estimation.demean_ import _set_demeaner_backend, demean, demean_model -from 
pyfixest.estimation.demean_jax_ import demean_jax +from pyfixest.estimation.jax.demean_jax_ import demean_jax @pytest.mark.parametrize( diff --git a/tests/test_detect_singletons.py b/tests/test_detect_singletons.py index d6c2af329..8f5480c37 100644 --- a/tests/test_detect_singletons.py +++ b/tests/test_detect_singletons.py @@ -2,7 +2,7 @@ import pytest from pyfixest.estimation.detect_singletons_ import detect_singletons -from pyfixest.estimation.detect_singletons_jax import detect_singletons_jax +from pyfixest.jax.detect_singletons_jax import detect_singletons_jax input1 = np.array([[0, 2, 1], [0, 2, 1], [0, 1, 3], [0, 1, 2], [0, 1, 2]]) solution1 = np.array([False, False, True, False, False]) From 80bd6a4ce2cb67bfd89811645eda55e5ac698586 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Tue, 14 Jan 2025 21:44:55 +0100 Subject: [PATCH 3/6] add jaxols --- pyfixest/estimation/jax/OLSJAX.py | 58 ++++++++++++++++--------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/pyfixest/estimation/jax/OLSJAX.py b/pyfixest/estimation/jax/OLSJAX.py index bde7e29df..2458bbfa3 100644 --- a/pyfixest/estimation/jax/OLSJAX.py +++ b/pyfixest/estimation/jax/OLSJAX.py @@ -1,13 +1,20 @@ +from typing import Optional + import jax import jax.numpy as jnp -from pyfixest.estimation.jax.demean_jax_ import demean_jax -from typing import Optional import pandas as pd -class OLSJAX: +from pyfixest.estimation.jax.demean_jax_ import demean_jax - def __init__(self, X: jax.Array, Y: jax.Array, fe: Optional[jax.Array] = None, vcov: str = 'iid'): +class OLSJAX: + def __init__( + self, + X: jax.Array, + Y: jax.Array, + fe: Optional[jax.Array] = None, + vcov: str = "iid", + ): self.X_orignal = X self.Y_orignal = Y self.fe = fe @@ -17,28 +24,28 @@ def __init__(self, X: jax.Array, Y: jax.Array, fe: Optional[jax.Array] = None, v self.vcov_type = vcov def fit(self): - - self.Y, self.X = self.demean(Y = self.Y_orignal, X = self.X_orignal, fe = self.fe, weights = self.weights) + self.Y, 
self.X = self.demean( + Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights + ) self.beta = jnp.linalg.lstsq(self.X, self.Y)[0] self.residuals self.scores - self.vcov(vcov_type = self.vcov_type) + self.vcov(vcov_type=self.vcov_type) self.inference() def tidy(self): - lb = 0.025 ub = 0.975 return pd.DataFrame( { - #"Coefficient": _coefnames, + # "Coefficient": _coefnames, "Estimate": self.beta.flatten(), "Std. Error": self.se.flatten(), "t value": self.tstat.flatten(), "Pr(>|t|)": self.pvalue.flatten(), - f"{lb * 100:.1f}%": self.confint[:,0].flatten(), - f"{ub * 100:.1f}%": self.confint[:,1].flatten(), + f"{lb * 100:.1f}%": self.confint[:, 0].flatten(), + f"{ub * 100:.1f}%": self.confint[:, 1].flatten(), } ) @@ -48,10 +55,9 @@ def residuals(self): return self.uhat def vcov(self, vcov_type: str): - bread = self.bread - meat = self.meat(type = vcov_type) - if vcov_type == 'iid': + meat = self.meat(type=vcov_type) + if vcov_type == "iid": self.vcov = bread * meat else: self.vcov = bread @ meat @ bread @@ -71,27 +77,25 @@ def scores(self): return self.X * self.residuals def meat(self, type: str): - - if type == 'iid': + if type == "iid": return self.meat_iid - elif type == 'HC1': + elif type == "HC1": return self.meat_hc1 - elif type == 'HC2': + elif type == "HC2": return self.meat_hc2 - elif type == 'HC3': + elif type == "HC3": return self.meat_hc3 - elif type == 'CRV1': + elif type == "CRV1": return self.meat_crv1 else: raise ValueError("Invalid type") @property def meat_iid(self): - return jnp.sum(self.uhat ** 2) / (self.N - self.k) + return jnp.sum(self.uhat**2) / (self.N - self.k) @property def meat_hc1(self): - return self.scores.T @ self.scores def meat_hc2(self): @@ -113,24 +117,23 @@ def predict(self, X): return X @ self.beta def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array): - if fe is not None: if not jnp.issubdtype(fe.dtype, jnp.integer): raise ValueError("Fixed effects must be integers") - YX = 
jnp.concatenate((Y, X), axis = 1) - YXd, success = demean_jax(x = YX, flist = fe, weights = self.weights, output = "jax") + YX = jnp.concatenate((Y, X), axis=1) + YXd, success = demean_jax( + x=YX, flist=fe, weights=self.weights, output="jax" + ) Yd = YXd[:, 0].reshape(-1, 1) Xd = YXd[:, 1:] return Yd, Xd else: - return Y, X def inference(self): - self.se = jnp.sqrt(jnp.diag(self.vcov)).reshape(-1, 1) self.tstat = self.beta / self.se self.pvalue = 2 * (1 - jax.scipy.stats.norm.cdf(jnp.abs(self.tstat))) @@ -141,5 +144,4 @@ def inference(self): ] ) - return self.se, self.tstat, self.pvalue, self.confint From 62da934c4bdb26ce763cbf6132983fc71f03b3d9 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Wed, 15 Jan 2025 20:41:55 +0100 Subject: [PATCH 4/6] interface --- pyfixest/estimation/FixestMulti_.py | 11 +++- pyfixest/estimation/jax/OLSJAX.py | 48 ++++++++------- pyfixest/estimation/jax/olsjax_interface.py | 65 +++++++++++++++++++++ 3 files changed, 102 insertions(+), 22 deletions(-) create mode 100644 pyfixest/estimation/jax/olsjax_interface.py diff --git a/pyfixest/estimation/FixestMulti_.py b/pyfixest/estimation/FixestMulti_.py index 1602d9072..275e451de 100644 --- a/pyfixest/estimation/FixestMulti_.py +++ b/pyfixest/estimation/FixestMulti_.py @@ -8,6 +8,7 @@ from pyfixest.estimation.fegaussian_ import Fegaussian from pyfixest.estimation.feiv_ import Feiv from pyfixest.estimation.felogit_ import Felogit +from pyfixest.estimation.jax.olsjax_interface import OLSJAX_API from pyfixest.estimation.feols_ import Feols, _check_vcov_input, _deparse_vcov_input from pyfixest.estimation.feols_compressed_ import FeolsCompressed from pyfixest.estimation.fepois_ import Fepois @@ -16,7 +17,6 @@ from pyfixest.utils.dev_utils import DataFrameType, _narwhals_to_pandas from pyfixest.utils.utils import capture_context - class FixestMulti: """A class to estimate multiple regression models with fixed effects.""" @@ -279,7 +279,11 @@ def _estimate_all_models( FIT: Union[Feols, Feiv, 
Fepois] if _method == "feols" and not _is_iv: - FIT = Feols( + + backend = "jax" + OLSCLASS = Feols if backend != "jax" else OLSJAX_API + + FIT = OLSCLASS( FixestFormula=FixestFormula, data=_data, ssc_dict=_ssc_dict, @@ -299,11 +303,14 @@ def _estimate_all_models( sample_split_value=sample_split_value, sample_split_var=_splitvar, ) + FIT.prepare_model_matrix() FIT.demean() FIT.to_array() FIT.drop_multicol_vars() FIT.wls_transform() + + elif _method == "feols" and _is_iv: FIT = Feiv( FixestFormula=FixestFormula, diff --git a/pyfixest/estimation/jax/OLSJAX.py b/pyfixest/estimation/jax/OLSJAX.py index 2458bbfa3..ed293da7b 100644 --- a/pyfixest/estimation/jax/OLSJAX.py +++ b/pyfixest/estimation/jax/OLSJAX.py @@ -6,22 +6,46 @@ from pyfixest.estimation.jax.demean_jax_ import demean_jax - class OLSJAX: def __init__( self, X: jax.Array, Y: jax.Array, fe: Optional[jax.Array] = None, - vcov: str = "iid", + weights: Optional[jax.Array] = None, + vcov: Optional[str, dict[str,str]] = None, ): + + """ + Class to run OLS regression in JAX. + + Parameters + ---------- + X : jax.Array + N x k matrix of independent variables. + Y : jax.Array + Dependent variable. N x 1 matrix. + fe : jax.Array, optional + Fixed effects. N x 1 matrix of integers. The default is None. + weights: jax.Array, optional + Weights. N x 1 matrix. The default is None. + vcov : str, optional + Type of covariance matrix. The default is None. Options are: + - "iid" (default): iid errors + - "HC1": heteroskedasticity robust + - "HC2": heteroskedasticity robust + - "HC3": heteroskedasticity robust + - "CRV1": cluster robust. In this case, please provide a dictionary + with the cluster variable as key and the name of the cluster variable as value. 
+ """ + self.X_orignal = X self.Y_orignal = Y self.fe = fe self.N = X.shape[0] self.k = X.shape[1] - self.weights = jnp.ones(self.N) - self.vcov_type = vcov + self.weights = jnp.ones(self.N) if weights is None else weights + self.vcov_type = "iid" if vcov is None else vcov def fit(self): self.Y, self.X = self.demean( @@ -33,22 +57,6 @@ def fit(self): self.vcov(vcov_type=self.vcov_type) self.inference() - def tidy(self): - lb = 0.025 - ub = 0.975 - - return pd.DataFrame( - { - # "Coefficient": _coefnames, - "Estimate": self.beta.flatten(), - "Std. Error": self.se.flatten(), - "t value": self.tstat.flatten(), - "Pr(>|t|)": self.pvalue.flatten(), - f"{lb * 100:.1f}%": self.confint[:, 0].flatten(), - f"{ub * 100:.1f}%": self.confint[:, 1].flatten(), - } - ) - @property def residuals(self): self.uhat = self.Y - self.X @ self.beta diff --git a/pyfixest/estimation/jax/olsjax_interface.py b/pyfixest/estimation/jax/olsjax_interface.py new file mode 100644 index 000000000..a78b65d22 --- /dev/null +++ b/pyfixest/estimation/jax/olsjax_interface.py @@ -0,0 +1,65 @@ +from pyfixest.estimation.feols_ import Feols +from pyfixest.estimation.feols_ import Feols, _drop_multicollinear_variables +from pyfixest.estimation.FormulaParser import FixestFormula +import pandas as pd +from typing import Union, Optional, Mapping, Any, Literal +import jax.numpy as jnp + +class OLSJAX_API(Feols): + + def __init__( + self, + FixestFormula: FixestFormula, + data: pd.DataFrame, + ssc_dict: dict[str, Union[str, bool]], + drop_singletons: bool, + drop_intercept: bool, + weights: Optional[str], + weights_type: Optional[str], + collin_tol: float, + fixef_tol: float, + lookup_demeaned_data: dict[str, pd.DataFrame], + solver: Literal[ + "np.linalg.lstsq", "np.linalg.solve", "scipy.sparse.linalg.lsqr", "jax" + ] = "np.linalg.solve", + demeaner_backend: Literal["numba", "jax"] = "numba", + store_data: bool = True, + copy_data: bool = True, + lean: bool = False, + context: Union[int, Mapping[str, Any]] = 0, + 
sample_split_var: Optional[str] = None, + sample_split_value: Optional[Union[str, int]] = None, + ) -> None: + super().__init__( + FixestFormula=FixestFormula, + data=data, + ssc_dict=ssc_dict, + drop_singletons=drop_singletons, + drop_intercept=drop_intercept, + weights=weights, + weights_type=weights_type, + collin_tol=collin_tol, + fixef_tol=fixef_tol, + lookup_demeaned_data=lookup_demeaned_data, + solver=solver, + store_data=store_data, + copy_data=copy_data, + lean=lean, + sample_split_var=sample_split_var, + sample_split_value=sample_split_value, + context=context, + demeaner_backend=demeaner_backend, + ) + + + + def to_array(self): + + """ + Convert all relevant data to JAX arrays. + """ + + self._Y = jnp.asarray(self._Y) + self._X = jnp.asarray(self._X) + self._fe = jnp.asarray(self._fe) + self._weights = jnp.asarray(self._weights) From 2c92c3d4d740cee0033f6554bef3fa8bd39530b3 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Sun, 19 Jan 2025 15:19:49 +0100 Subject: [PATCH 5/6] POC hacky pure JAX OLS class via interface --- pyfixest/estimation/FixestMulti_.py | 38 ++++++++----- pyfixest/estimation/feols_.py | 1 + pyfixest/estimation/jax/OLSJAX.py | 28 ++++------ pyfixest/estimation/jax/olsjax_interface.py | 60 ++++++++++++++++++--- 4 files changed, 86 insertions(+), 41 deletions(-) diff --git a/pyfixest/estimation/FixestMulti_.py b/pyfixest/estimation/FixestMulti_.py index 275e451de..62c9666d9 100644 --- a/pyfixest/estimation/FixestMulti_.py +++ b/pyfixest/estimation/FixestMulti_.py @@ -304,11 +304,12 @@ def _estimate_all_models( sample_split_var=_splitvar, ) - FIT.prepare_model_matrix() - FIT.demean() - FIT.to_array() - FIT.drop_multicol_vars() - FIT.wls_transform() + if backend != "jax": + FIT.prepare_model_matrix() + FIT.demean() + FIT.to_array() + FIT.drop_multicol_vars() + FIT.wls_transform() elif _method == "feols" and _is_iv: @@ -337,6 +338,7 @@ def _estimate_all_models( FIT.to_array() FIT.drop_multicol_vars() FIT.wls_transform() + elif _method 
== "fepois": FIT = Fepois( FixestFormula=FixestFormula, @@ -482,17 +484,25 @@ def _estimate_all_models( FIT.get_fit() # if X is empty: no inference (empty X only as shorthand for demeaning) - if not FIT._X_is_empty: - # inference - vcov_type = _get_vcov_type(vcov, fval) - FIT.vcov(vcov=vcov_type, data=FIT._data) + if backend != "jax": + if not FIT._X_is_empty: + # inference + vcov_type = _get_vcov_type(vcov, fval) + FIT.vcov(vcov=vcov_type, data=FIT._data) + + FIT.get_inference() + # other regression stats + if _method == "feols" and not FIT._is_iv: + FIT.get_performance() + if isinstance(FIT, Feiv): + FIT.first_stage() + + else: + #import pdb; pdb.set_trace() + FIT.vcov(type = "iid") + FIT.convert_attributes_to_numpy() FIT.get_inference() - # other regression stats - if _method == "feols" and not FIT._is_iv: - FIT.get_performance() - if isinstance(FIT, Feiv): - FIT.first_stage() # delete large attributescl FIT._clear_attributes() diff --git a/pyfixest/estimation/feols_.py b/pyfixest/estimation/feols_.py index 113751fb9..dd4e1e4f8 100644 --- a/pyfixest/estimation/feols_.py +++ b/pyfixest/estimation/feols_.py @@ -816,6 +816,7 @@ def get_inference(self, alpha: float = 0.05) -> None: ------- None """ + _vcov = self._vcov _beta_hat = self._beta_hat _vcov_type = self._vcov_type diff --git a/pyfixest/estimation/jax/OLSJAX.py b/pyfixest/estimation/jax/OLSJAX.py index ed293da7b..44606a337 100644 --- a/pyfixest/estimation/jax/OLSJAX.py +++ b/pyfixest/estimation/jax/OLSJAX.py @@ -13,7 +13,7 @@ def __init__( Y: jax.Array, fe: Optional[jax.Array] = None, weights: Optional[jax.Array] = None, - vcov: Optional[str, dict[str,str]] = None, + vcov: Optional[str] = None, ): """ @@ -52,15 +52,17 @@ def fit(self): Y=self.Y_orignal, X=self.X_orignal, fe=self.fe, weights=self.weights ) self.beta = jnp.linalg.lstsq(self.X, self.Y)[0] - self.residuals + self.get_fit() self.scores self.vcov(vcov_type=self.vcov_type) self.inference() + def get_fit(self): + self.beta = 
jnp.linalg.lstsq(self.X, self.Y)[0] + @property def residuals(self): - self.uhat = self.Y - self.X @ self.beta - return self.uhat + return self.Y - self.X @ self.beta def vcov(self, vcov_type: str): bread = self.bread @@ -100,7 +102,7 @@ def meat(self, type: str): @property def meat_iid(self): - return jnp.sum(self.uhat**2) / (self.N - self.k) + return jnp.sum(self.residuals**2) / (self.N - self.k) @property def meat_hc1(self): @@ -125,13 +127,14 @@ def predict(self, X): return X @ self.beta def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array): + if fe is not None: if not jnp.issubdtype(fe.dtype, jnp.integer): raise ValueError("Fixed effects must be integers") YX = jnp.concatenate((Y, X), axis=1) YXd, success = demean_jax( - x=YX, flist=fe, weights=self.weights, output="jax" + x=YX, flist=fe, weights=weights, output="jax" ) Yd = YXd[:, 0].reshape(-1, 1) Xd = YXd[:, 1:] @@ -140,16 +143,3 @@ def demean(self, Y: jax.Array, X: jax.Array, fe: jax.Array, weights: jax.Array): else: return Y, X - - def inference(self): - self.se = jnp.sqrt(jnp.diag(self.vcov)).reshape(-1, 1) - self.tstat = self.beta / self.se - self.pvalue = 2 * (1 - jax.scipy.stats.norm.cdf(jnp.abs(self.tstat))) - self.confint = jnp.column_stack( - [ - self.beta - jax.scipy.stats.norm.ppf(1 - 0.05 / 2) * self.se, - self.beta + jax.scipy.stats.norm.ppf(1 - 0.05 / 2) * self.se, - ] - ) - - return self.se, self.tstat, self.pvalue, self.confint diff --git a/pyfixest/estimation/jax/olsjax_interface.py b/pyfixest/estimation/jax/olsjax_interface.py index a78b65d22..643b692a2 100644 --- a/pyfixest/estimation/jax/olsjax_interface.py +++ b/pyfixest/estimation/jax/olsjax_interface.py @@ -2,8 +2,10 @@ from pyfixest.estimation.feols_ import Feols, _drop_multicollinear_variables from pyfixest.estimation.FormulaParser import FixestFormula import pandas as pd +import numpy as np from typing import Union, Optional, Mapping, Any, Literal import jax.numpy as jnp +from 
pyfixest.estimation.jax.OLSJAX import OLSJAX class OLSJAX_API(Feols): @@ -51,15 +53,57 @@ def __init__( demeaner_backend=demeaner_backend, ) + self.prepare_model_matrix() + self.to_jax_array() + + # later to be set in multicoll method + self._N, self._k = self._X_jax.shape + + self.olsjax = OLSJAX( + X=self._X_jax, + Y=self._Y_jax, + fe=self._fe_jax, + weights=self._weights_jax, + vcov="iid", + ) + #import pdb; pdb.set_trace() + self.olsjax.Y, self.olsjax.X = self.olsjax.demean(Y = self._Y_jax, X = self._X_jax, fe = self._fe_jax, weights = self._weights_jax.flatten()) + + def to_jax_array(self): + + self._X_jax = jnp.array(self._X) + self._Y_jax = jnp.array(self._Y) + self._fe_jax = jnp.array(self._fe) + self._weights_jax = jnp.array(self._weights) + + + def get_fit(self): + + self.olsjax.get_fit() + self._beta_hat = self.olsjax.beta.flatten() + self._u_hat = self.olsjax.residuals + self._scores = self.olsjax.scores + + def vcov(self, type: str): + + self._vcov_type = type + self.olsjax.vcov(vcov_type=type) + self._vcov = self.olsjax.vcov + + return self + + def convert_attributes_to_numpy(self): + "Convert core attributes from jax to numpy arrays." + attr = ["_beta_hat", "_u_hat", "_scores", "_vcov"] + for a in attr: + # convert to numpy + setattr(self, a, np.array(getattr(self, a))) + + + + + - def to_array(self): - """ - Convert all relevant data to JAX arrays. 
- """ - self._Y = jnp.asarray(self._Y) - self._X = jnp.asarray(self._X) - self._fe = jnp.asarray(self._fe) - self._weights = jnp.asarray(self._weights) From 3f595cd32feca8d4d1d64496382b41532ebfb970 Mon Sep 17 00:00:00 2001 From: Alexander Fischer Date: Sun, 19 Jan 2025 15:31:42 +0100 Subject: [PATCH 6/6] delete GPU benchmarks --- benchmarks/gpu_pyfixest_errors.ipynb | 1478 -------------------------- 1 file changed, 1478 deletions(-) delete mode 100644 benchmarks/gpu_pyfixest_errors.ipynb diff --git a/benchmarks/gpu_pyfixest_errors.ipynb b/benchmarks/gpu_pyfixest_errors.ipynb deleted file mode 100644 index 945569cd9..000000000 --- a/benchmarks/gpu_pyfixest_errors.ipynb +++ /dev/null @@ -1,1478 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:23:28.917621Z", - "iopub.status.busy": "2025-01-09T00:23:28.917332Z", - "iopub.status.idle": "2025-01-09T00:23:29.477701Z", - "shell.execute_reply": "2025-01-09T00:23:29.477193Z", - "shell.execute_reply.started": "2025-01-09T00:23:28.917602Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[CpuDevice(id=0)]" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "jax.devices()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:23:30.540594Z", - "iopub.status.busy": "2025-01-09T00:23:30.540228Z", - "iopub.status.idle": "2025-01-09T00:23:30.739685Z", - "shell.execute_reply": "2025-01-09T00:23:30.739213Z", - "shell.execute_reply.started": "2025-01-09T00:23:30.540574Z" - }, - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{CpuDevice(id=0)}" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jnp.ones(10).devices()" - ] - }, - { - "cell_type": 
"code", - "execution_count": 3, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:26:29.239253Z", - "iopub.status.busy": "2025-01-09T00:26:29.238947Z", - "iopub.status.idle": "2025-01-09T00:26:29.754752Z", - "shell.execute_reply": "2025-01-09T00:26:29.754158Z", - "shell.execute_reply.started": "2025-01-09T00:26:29.239235Z" - } - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "'nvidia-smi' is not recognized as an internal or external command,\n", - "operable program or batch file.\n" - ] - } - ], - "source": [ - "!nvidia-smi" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:23:44.232335Z", - "iopub.status.busy": "2025-01-09T00:23:44.231984Z", - "iopub.status.idle": "2025-01-09T00:23:45.388035Z", - "shell.execute_reply": "2025-01-09T00:23:45.387587Z", - "shell.execute_reply.started": "2025-01-09T00:23:44.232312Z" - }, - "id": "fHzEldNvR2_K" - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " " - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import time\n", - "from itertools import product\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "from scipy.stats import nbinom\n", - "from tqdm import tqdm\n", - "\n", - "import pyfixest as pf\n", - "from pyfixest.estimation.demean_ import demean\n", - "from pyfixest.estimation.jax.demean_jax_ import demean_jax" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "execution": { - "iopub.execute_input": "2025-01-09T00:23:46.290548Z", - "iopub.status.busy": "2025-01-09T00:23:46.289898Z", - "iopub.status.idle": "2025-01-09T00:23:46.417097Z", - "shell.execute_reply": "2025-01-09T00:23:46.416504Z", - "shell.execute_reply.started": "2025-01-09T00:23:46.290525Z" - }, - "id": "XQjP2889YJxs", - "outputId": "3e686d7b-0774-4bb5-c1b9-28e5b9f286a9" - }, - "outputs": [], - "source": [ - "# %load_ext watermark\n", - "# %watermark --iversions" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "background_save": true - }, - "execution": { - "iopub.execute_input": "2025-01-09T00:23:49.545271Z", - "iopub.status.busy": "2025-01-09T00:23:49.545016Z", - "iopub.status.idle": "2025-01-09T00:23:49.552123Z", - "shell.execute_reply": "2025-01-09T00:23:49.551676Z", - "shell.execute_reply.started": "2025-01-09T00:23:49.545253Z" - }, - "id": "bxMmeyCxR3fb" - }, - "outputs": [], - "source": [ - "def generate_test_data(size: int, k: int = 2):\n", - " \"\"\"\n", - " Generate benchmark data for pyfixest on GPU (similar to the R fixest benchmark data).\n", - "\n", - " Args:\n", - " size (int): The number of observations in the data frame.\n", - " k (int): The number of covariates in the data frame.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The generated data frame 
for the given size.\n", - " \"\"\"\n", - " # Constants\n", - " all_n = [1000 * 10**i for i in range(5)]\n", - " a = 1\n", - " b = 0.05\n", - "\n", - " n = all_n[size - 1]\n", - "\n", - " dum_all = []\n", - " nb_dum = [n // 20, int(np.sqrt(n)), int(n**0.33)]\n", - "\n", - " dum_all = np.zeros((n, 3))\n", - " dum_all[:, 0] = np.random.choice(nb_dum[0], n, replace=True)\n", - " dum_all[:, 1] = np.random.choice(nb_dum[1], n, replace=True)\n", - " dum_all[:, 2] = np.random.choice(nb_dum[2], n, replace=True)\n", - " dum_all = dum_all.astype(int)\n", - "\n", - " X1 = np.random.normal(size=n)\n", - " X2 = X1**2\n", - "\n", - " mu = a * X1 + b * X2\n", - "\n", - " for m in range(3):\n", - " coef_dum = np.random.normal(size=nb_dum[m])\n", - " mu += coef_dum[dum_all[:, m]]\n", - "\n", - " mu = np.exp(mu)\n", - " y = nbinom.rvs(0.5, 1 - (mu / (mu + 0.5)), size=n)\n", - "\n", - " X_full = np.column_stack((X1, X2))\n", - " base = pd.DataFrame(\n", - " {\n", - " \"y\": y,\n", - " \"ln_y\": np.log(y + 1),\n", - " \"X1\": X1,\n", - " \"X2\": X2,\n", - " }\n", - " )\n", - "\n", - " if k > 2:\n", - " X = np.random.normal(size=(n, k - 2))\n", - " X_df = pd.DataFrame(X, columns=[f\"X{i}\" for i in range(3, k + 1, 1)])\n", - " base = pd.concat([base, X_df], axis=1)\n", - " X_full = np.column_stack((X_full, X))\n", - "\n", - " for m in range(3):\n", - " base[f\"dum_{m + 1}\"] = dum_all[:, m]\n", - "\n", - " weights = np.random.uniform(0, 1, n)\n", - " return base, y, X_full, dum_all, weights" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:23:50.285297Z", - "iopub.status.busy": "2025-01-09T00:23:50.284967Z", - "iopub.status.idle": "2025-01-09T00:23:50.460957Z", - "shell.execute_reply": "2025-01-09T00:23:50.460501Z", - "shell.execute_reply.started": "2025-01-09T00:23:50.285276Z" - }, - "id": "nzynhbqwR81H" - }, - "outputs": [], - "source": [ - "df, Y, X, f, weights = generate_test_data(1)" - ] - }, - { - 
"cell_type": "code", - "execution_count": 8, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:25:02.873750Z", - "iopub.status.busy": "2025-01-09T00:25:02.873239Z", - "iopub.status.idle": "2025-01-09T00:25:03.153458Z", - "shell.execute_reply": "2025-01-09T00:25:03.153005Z", - "shell.execute_reply.started": "2025-01-09T00:25:02.873732Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "###\n", - "\n", - "Estimation: OLS\n", - "Dep. var.: ln_y, Fixed effects: dum_1\n", - "Inference: CRV1\n", - "Observations: 1000\n", - "\n", - "| Coefficient | Estimate | Std. Error | t value | Pr(>|t|) | 2.5% | 97.5% |\n", - "|:--------------|-----------:|-------------:|----------:|-----------:|-------:|--------:|\n", - "| X1 | 0.436 | 0.046 | 9.440 | 0.000 | 0.343 | 0.529 |\n", - "---\n", - "RMSE: 1.067 R2: 0.242 R2 Within: 0.131 \n" - ] - } - ], - "source": [ - "m0 = pf.feols(\"ln_y ~ X1 | dum_1\", df, demeaner_backend=\"numba\")\n", - "m0.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:25:03.330646Z", - "iopub.status.busy": "2025-01-09T00:25:03.330367Z", - "iopub.status.idle": "2025-01-09T00:25:03.571916Z", - "shell.execute_reply": "2025-01-09T00:25:03.571482Z", - "shell.execute_reply.started": "2025-01-09T00:25:03.330629Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "###\n", - "\n", - "Estimation: OLS\n", - "Dep. var.: ln_y, Fixed effects: dum_1\n", - "Inference: CRV1\n", - "Observations: 1000\n", - "\n", - "| Coefficient | Estimate | Std. 
Error | t value | Pr(>|t|) | 2.5% | 97.5% |\n", - "|:--------------|-----------:|-------------:|----------:|-----------:|-------:|--------:|\n", - "| X1 | 0.436 | 0.046 | 9.440 | 0.000 | 0.343 | 0.529 |\n", - "---\n", - "RMSE: 1.067 R2: 0.242 R2 Within: 0.131 \n" - ] - } - ], - "source": [ - "m1 = pf.feols(\"ln_y ~ X1 | dum_1\", df, demeaner_backend=\"jax\")\n", - "m1.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## function" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:24:06.619552Z", - "iopub.status.busy": "2025-01-09T00:24:06.619273Z", - "iopub.status.idle": "2025-01-09T00:24:06.626298Z", - "shell.execute_reply": "2025-01-09T00:24:06.625727Z", - "shell.execute_reply.started": "2025-01-09T00:24:06.619534Z" - }, - "id": "29rZkULUR_A0" - }, - "outputs": [], - "source": [ - "def run_standard_benchmark(\n", - " fixed_effect,\n", - " demeaner_backend,\n", - " size=1,\n", - " k=1,\n", - " solver=\"np.linalg.lstsq\",\n", - " skip_demean_benchmark=True,\n", - "):\n", - " \"\"\"\n", - " Run the fixest standard benchmark fixed effect models. This is the function the benchmarks\n", - " will loop over.\n", - "\n", - " Args:\n", - " fixed_effect (str): The fixed effect to use. Must be a list of variables as \"dum_1\", \"dum_1+dum_2\", or \"dum_1+dum_2+dum_3\", etc.\n", - " demeaner_backend (str): The backend to use for demeaning. Must be \"numba\" or \"jax\".\n", - " size (int): The size of the data to generate. Must be between 1 and 5. For 1, N = 1000, for 2, N = 10000, etc.\n", - " k_vals (int): The number of covariates to generate.\n", - " solver (str): The solver to use for the estimation. Must be \"np.linalg.lstsq\". \"jax\" currently throws an error.\n", - " skip_demean_benchmark (bool): Whether to skip the \"pure\" demean benchmark. Default is True. 
Only the full call\n", - " to feols is benchmarked.\n", - "\n", - " \"\"\"\n", - " assert fixed_effect in [\"dum_1\", \"dum_1+dum_2\", \"dum_1+dum_2+dum_3\"]\n", - "\n", - " # one fixed effect\n", - " res = []\n", - "\n", - " fml_base = \"ln_y ~ X1\"\n", - " fml = f\"{fml_base} | {fixed_effect}\"\n", - "\n", - " # warmup\n", - " df, y, X, f, weights = generate_test_data(1)\n", - " pf.feols(\n", - " fml,\n", - " data=df,\n", - " demeaner_backend=demeaner_backend,\n", - " store_data=False,\n", - " copy_data=False,\n", - " solver=solver,\n", - " )\n", - "\n", - " if k > 1:\n", - " xfml = \"+\".join([f\"X{i}\" for i in range(2, k + 1, 1)])\n", - " fml = f\"{fml_base} + {xfml} | {fixed_effect}\"\n", - " else:\n", - " fml = f\"{fml_base} + X1 | {fixed_effect}\"\n", - "\n", - " for rep in range(1, 11):\n", - " df, Y, X, f, weights = generate_test_data(size=size, k=k)\n", - "\n", - " tic1 = time.time()\n", - " pf.feols(\n", - " fml,\n", - " data=df,\n", - " demeaner_backend=demeaner_backend,\n", - " store_data=False,\n", - " copy_data=False,\n", - " solver=solver,\n", - " )\n", - " tic2 = time.time()\n", - "\n", - " full_feols_timing = tic2 - tic1\n", - "\n", - " demean_timing = np.nan\n", - " if not skip_demean_benchmark:\n", - " YX = np.column_stack((Y.reshape(-1, 1), X))\n", - " tic3 = time.time()\n", - " if demeaner_backend == \"jax\":\n", - " _, _ = demean_jax(YX, f, weights, tol=1e-10)\n", - " else:\n", - " _, _ = demean(YX, f, weights, tol=1e-10)\n", - " tic4 = time.time()\n", - " demean_timing = tic4 - tic3\n", - "\n", - " res.append(\n", - " pd.Series(\n", - " {\n", - " \"method\": \"feols\",\n", - " \"solver\": solver,\n", - " \"demeaner_backend\": demeaner_backend,\n", - " \"n_obs\": df.shape[0],\n", - " \"k\": k,\n", - " \"G\": len(fixed_effect.split(\"+\")),\n", - " \"rep\": rep,\n", - " \"full_feols_timing\": full_feols_timing,\n", - " \"demean_timing\": demean_timing,\n", - " }\n", - " )\n", - " )\n", - "\n", - " return pd.concat(res, axis=1).T" - ] - }, - { 
- "cell_type": "code", - "execution_count": 11, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:28:43.818536Z", - "iopub.status.busy": "2025-01-09T00:28:43.818246Z", - "iopub.status.idle": "2025-01-09T00:28:51.489202Z", - "shell.execute_reply": "2025-01-09T00:28:51.488591Z", - "shell.execute_reply.started": "2025-01-09T00:28:43.818520Z" - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
methodsolverdemeaner_backendn_obskGrepfull_feols_timingdemean_timing
0feolsnp.linalg.lstsqnumba10001110.150473NaN
1feolsnp.linalg.lstsqnumba10001120.147583NaN
2feolsnp.linalg.lstsqnumba10001130.186491NaN
3feolsnp.linalg.lstsqnumba10001140.190972NaN
4feolsnp.linalg.lstsqnumba10001150.162773NaN
5feolsnp.linalg.lstsqnumba10001160.171777NaN
6feolsnp.linalg.lstsqnumba10001170.166872NaN
7feolsnp.linalg.lstsqnumba10001180.158694NaN
8feolsnp.linalg.lstsqnumba10001190.185547NaN
9feolsnp.linalg.lstsqnumba100011100.158114NaN
\n", - "
" - ], - "text/plain": [ - " method solver demeaner_backend n_obs k G rep full_feols_timing \\\n", - "0 feols np.linalg.lstsq numba 1000 1 1 1 0.150473 \n", - "1 feols np.linalg.lstsq numba 1000 1 1 2 0.147583 \n", - "2 feols np.linalg.lstsq numba 1000 1 1 3 0.186491 \n", - "3 feols np.linalg.lstsq numba 1000 1 1 4 0.190972 \n", - "4 feols np.linalg.lstsq numba 1000 1 1 5 0.162773 \n", - "5 feols np.linalg.lstsq numba 1000 1 1 6 0.171777 \n", - "6 feols np.linalg.lstsq numba 1000 1 1 7 0.166872 \n", - "7 feols np.linalg.lstsq numba 1000 1 1 8 0.158694 \n", - "8 feols np.linalg.lstsq numba 1000 1 1 9 0.185547 \n", - "9 feols np.linalg.lstsq numba 1000 1 1 10 0.158114 \n", - "\n", - " demean_timing \n", - "0 NaN \n", - "1 NaN \n", - "2 NaN \n", - "3 NaN \n", - "4 NaN \n", - "5 NaN \n", - "6 NaN \n", - "7 NaN \n", - "8 NaN \n", - "9 NaN " - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# test run numba\n", - "run_standard_benchmark(fixed_effect=\"dum_1\", demeaner_backend=\"numba\", size=1, k=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:28:43.818536Z", - "iopub.status.busy": "2025-01-09T00:28:43.818246Z", - "iopub.status.idle": "2025-01-09T00:28:51.489202Z", - "shell.execute_reply": "2025-01-09T00:28:51.488591Z", - "shell.execute_reply.started": "2025-01-09T00:28:43.818520Z" - } - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
methodsolverdemeaner_backendn_obskGrepfull_feols_timingdemean_timing
0feolsnp.linalg.lstsqjax10001110.122831NaN
1feolsnp.linalg.lstsqjax10001120.122887NaN
2feolsnp.linalg.lstsqjax10001130.136041NaN
3feolsnp.linalg.lstsqjax10001140.139644NaN
4feolsnp.linalg.lstsqjax10001150.136235NaN
5feolsnp.linalg.lstsqjax10001160.122477NaN
6feolsnp.linalg.lstsqjax10001170.123122NaN
7feolsnp.linalg.lstsqjax10001180.119589NaN
8feolsnp.linalg.lstsqjax10001190.122247NaN
9feolsnp.linalg.lstsqjax100011100.118353NaN
\n", - "
" - ], - "text/plain": [ - " method solver demeaner_backend n_obs k G rep full_feols_timing \\\n", - "0 feols np.linalg.lstsq jax 1000 1 1 1 0.122831 \n", - "1 feols np.linalg.lstsq jax 1000 1 1 2 0.122887 \n", - "2 feols np.linalg.lstsq jax 1000 1 1 3 0.136041 \n", - "3 feols np.linalg.lstsq jax 1000 1 1 4 0.139644 \n", - "4 feols np.linalg.lstsq jax 1000 1 1 5 0.136235 \n", - "5 feols np.linalg.lstsq jax 1000 1 1 6 0.122477 \n", - "6 feols np.linalg.lstsq jax 1000 1 1 7 0.123122 \n", - "7 feols np.linalg.lstsq jax 1000 1 1 8 0.119589 \n", - "8 feols np.linalg.lstsq jax 1000 1 1 9 0.122247 \n", - "9 feols np.linalg.lstsq jax 1000 1 1 10 0.118353 \n", - "\n", - " demean_timing \n", - "0 NaN \n", - "1 NaN \n", - "2 NaN \n", - "3 NaN \n", - "4 NaN \n", - "5 NaN \n", - "6 NaN \n", - "7 NaN \n", - "8 NaN \n", - "9 NaN " - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# test run jax\n", - "run_standard_benchmark(fixed_effect=\"dum_1\", demeaner_backend=\"jax\", size=1, k=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "execution": { - "iopub.execute_input": "2025-01-09T00:25:39.248695Z", - "iopub.status.busy": "2025-01-09T00:25:39.248317Z", - "iopub.status.idle": "2025-01-09T00:25:39.253652Z", - "shell.execute_reply": "2025-01-09T00:25:39.253000Z", - "shell.execute_reply.started": "2025-01-09T00:25:39.248671Z" - }, - "id": "HxsGRMSlR_jI" - }, - "outputs": [], - "source": [ - "def run_all_benchmarks(size_list, k_list):\n", - " \"\"\"\n", - " Run all the benchmarks.\n", - "\n", - " Args:\n", - " size_list (list): The list of sizes to run the benchmarks on. 
1-> 1000, 2-> 10000, ..., 5-> 10_000_000\n", - " k_list (list): The list of k values to run the benchmarks on.\n", - " \"\"\"\n", - " res = pd.DataFrame()\n", - "\n", - " all_combinations = list(\n", - " product(\n", - " [\"numba\", \"jax\"], # demeaner_backend\n", - " [\"dum_1\", \"dum_1+dum_2\", \"dum_1+dum_2+dum_3\"], # fixef\n", - " size_list, # size\n", - " k_list, # k\n", - " [\"np.linalg.lstsq\"], # solver\n", - " )\n", - " )\n", - "\n", - " with tqdm(total=len(all_combinations), desc=\"Running Benchmarks\") as pbar:\n", - " for demeaner_backend, fixef, size, k, solver in all_combinations:\n", - " res = pd.concat(\n", - " [\n", - " res,\n", - " run_standard_benchmark(\n", - " solver=solver,\n", - " fixed_effect=fixef,\n", - " demeaner_backend=demeaner_backend,\n", - " size=size,\n", - " k=k,\n", - " ),\n", - " ],\n", - " axis=0,\n", - " )\n", - " pbar.update(1) # Update the progress bar after each iteration\n", - "\n", - " return res" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Run Benchmarks" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "background_save": true, - "base_uri": "https://localhost:8080/" - }, - "execution": { - "iopub.execute_input": "2025-01-09T00:25:39.962554Z", - "iopub.status.busy": "2025-01-09T00:25:39.962039Z", - "iopub.status.idle": "2025-01-09T00:26:25.319310Z", - "shell.execute_reply": "2025-01-09T00:26:25.318687Z", - "shell.execute_reply.started": "2025-01-09T00:25:39.962536Z" - }, - "id": "gki1mlqvSEIi", - "outputId": "3cb40095-df81-4e78-99a6-2410da237884" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Running Benchmarks: 100%|██████████| 24/24 [00:49<00:00, 2.06s/it]\n" - ] - } - ], - "source": [ - "res_all = run_all_benchmarks(\n", - " size_list=[1, 2, 3, 4, 5], # for N = 1000, 10_000, 100_000, 1_000_000, 10_000_000\n", - " k_list=[1, 10, 50, 100], # for k = 1, 10, 50, 100\n", - ")" - ] - }, - { - "cell_type": 
"code", - "execution_count": 15, - "metadata": { - "colab": { - "background_save": true, - "base_uri": "https://localhost:8080/" - }, - "execution": { - "iopub.execute_input": "2025-01-09T00:26:25.320894Z", - "iopub.status.busy": "2025-01-09T00:26:25.320596Z", - "iopub.status.idle": "2025-01-09T00:26:26.391854Z", - "shell.execute_reply": "2025-01-09T00:26:26.391236Z", - "shell.execute_reply.started": "2025-01-09T00:26:25.320871Z" - }, - "id": "7zEIHj5nXXvq", - "outputId": "238dedce-a9f5-4a1b-b9a2-b84d4a89c091" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
methoddemeaner_backendkGn_obsfull_feols_timingdemean_timing
0feolsjax1110000.127274NaN
1feolsjax11100000.142241NaN
2feolsjax1210000.134405NaN
3feolsjax12100000.168555NaN
4feolsjax1310000.139142NaN
5feolsjax13100000.189492NaN
6feolsjax5110000.144933NaN
7feolsjax51100000.147049NaN
8feolsjax5210000.147216NaN
9feolsjax52100000.175394NaN
10feolsjax5310000.15034NaN
11feolsjax53100000.194496NaN
12feolsnumba1110000.153849NaN
13feolsnumba11100000.183697NaN
14feolsnumba1210000.17201NaN
15feolsnumba12100000.16131NaN
16feolsnumba1310000.169463NaN
17feolsnumba13100000.169789NaN
18feolsnumba5110000.331457NaN
19feolsnumba51100000.172194NaN
20feolsnumba5210000.162562NaN
21feolsnumba52100000.184233NaN
22feolsnumba5310000.174623NaN
23feolsnumba53100000.166674NaN
\n", - "
" - ], - "text/plain": [ - " method demeaner_backend k G n_obs full_feols_timing demean_timing\n", - "0 feols jax 1 1 1000 0.127274 NaN\n", - "1 feols jax 1 1 10000 0.142241 NaN\n", - "2 feols jax 1 2 1000 0.134405 NaN\n", - "3 feols jax 1 2 10000 0.168555 NaN\n", - "4 feols jax 1 3 1000 0.139142 NaN\n", - "5 feols jax 1 3 10000 0.189492 NaN\n", - "6 feols jax 5 1 1000 0.144933 NaN\n", - "7 feols jax 5 1 10000 0.147049 NaN\n", - "8 feols jax 5 2 1000 0.147216 NaN\n", - "9 feols jax 5 2 10000 0.175394 NaN\n", - "10 feols jax 5 3 1000 0.15034 NaN\n", - "11 feols jax 5 3 10000 0.194496 NaN\n", - "12 feols numba 1 1 1000 0.153849 NaN\n", - "13 feols numba 1 1 10000 0.183697 NaN\n", - "14 feols numba 1 2 1000 0.17201 NaN\n", - "15 feols numba 1 2 10000 0.16131 NaN\n", - "16 feols numba 1 3 1000 0.169463 NaN\n", - "17 feols numba 1 3 10000 0.169789 NaN\n", - "18 feols numba 5 1 1000 0.331457 NaN\n", - "19 feols numba 5 1 10000 0.172194 NaN\n", - "20 feols numba 5 2 1000 0.162562 NaN\n", - "21 feols numba 5 2 10000 0.184233 NaN\n", - "22 feols numba 5 3 1000 0.174623 NaN\n", - "23 feols numba 5 3 10000 0.166674 NaN" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = (\n", - " res_all.drop([\"rep\", \"solver\"], axis=1)\n", - " .groupby([\"method\", \"demeaner_backend\", \"k\", \"G\", \"n_obs\"])\n", - " .mean()\n", - " .reset_index()\n", - ")\n", - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab": { - "background_save": true - }, - "id": "VCn6O5MMXlBw" - }, - "source": [ - "## Visualize" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. 
If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. 
If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n", - "INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "df[\"G\"] = df[\"G\"].map({1: \"n_fixef = 1\", 2: \"n_fixef = 2\", 3: \"n_fixef = 3\"})\n", - "df[\"n_obs\"] = df[\"n_obs\"].astype(str)\n", - "\n", - "# Dynamically determine unique values for order and hue_order\n", - "n_obs_order = sorted(df[\"n_obs\"].unique(), key=lambda x: int(x)) # Sort as integers\n", - "demeaner_backend_order = df[\"demeaner_backend\"].unique()\n", - "\n", - "custom_palette = sns.color_palette(\"coolwarm\", n_colors=2)\n", - "\n", - "# Create the FacetGrid with reordered columns and rows\n", - "g = sns.FacetGrid(\n", - " df,\n", - " col=\"G\", # G (n_fixef) increases left to right\n", - " row=\"k\", # k increases top to bottom\n", - " margin_titles=True,\n", - " height=4,\n", - " aspect=1.2,\n", - " col_order=[\"n_fixef = 1\", \"n_fixef = 2\", \"n_fixef = 3\"], # Ensure correct order\n", - ")\n", - "\n", - "# Plot the bar chart for each facet with the custom palette\n", - "g.map(\n", - " sns.barplot,\n", - " \"n_obs\",\n", - " \"full_feols_timing\",\n", - " \"demeaner_backend\",\n", - " order=n_obs_order, # Dynamic order for n_obs\n", - " hue_order=demeaner_backend_order, # Dynamic hue order for demeaner_backend\n", - " errorbar=None, # Suppress error bars\n", - " palette=custom_palette,\n", - ")\n", - "\n", - "# Add legend and adjust layout\n", - "g.add_legend(title=\"Demeaner Backend\")\n", - "g.set_axis_labels(\"Number of Observations\", \"Runtime (seconds)\")\n", - "g.set_titles(row_template=\"k = {row_name}\", col_template=\"{col_name}\")\n", - "plt.subplots_adjust(top=0.9)\n", - "g.fig.suptitle(\"Runtime vs Number of Observations by n_fixef and k\")\n", - "\n", - "# Show plot\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "jax", - "language": "python", - "name": 
"python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}