From 25665bc89b549c17cef9831a744353446019132f Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 01:42:39 -0500 Subject: [PATCH 01/17] Created using Colab --- chronos.ipynb | 258 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 258 insertions(+) create mode 100644 chronos.ipynb diff --git a/chronos.ipynb b/chronos.ipynb new file mode 100644 index 0000000..4301ea5 --- /dev/null +++ b/chronos.ipynb @@ -0,0 +1,258 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyO3JkXtZ8GmkdSpgF01FRYS", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "259--YZOr1Lo" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from tsfmproject.model import ChronosModel\n", + "from tsfmproject.dataset import ChronosDataset\n", + "from tsfmproject.visualization import ForecastVisualization" + ], + "metadata": { + "id": "APfhHKAU1qM3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Training Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = ChronosDataset(name=\"ett\", mode=\"train\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64)\n", + "test_dataset = ChronosDataset(name=\"ett\", mode=\"test\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64)\n", + "\n", + "print(len(test_dataset.dataset))\n", + "# print(test_dataset.dataset.shape)" + ], + "metadata": { + "id": "MyUXzOME1sgo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading the Chronos Model**" + ], + "metadata": { + "id": "OgcxscvD3Ee8" + } + }, + { + "cell_type": "code", + "source": [ + "repo = \"amazon/chronos-t5-small\"\n", + "ch = ChronosModel(config=None, repo=repo)\n", + "ch.load_model()" + ], + "metadata": { + "id": "K9Zmx3li1u5k" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluated Results**" + ], + "metadata": { + "id": "2I0VE7Uu3Jqt" + } + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"])\n", + "print(eval_results)\n", + "# visualization = ForecastVisualization(trues, preds[:,:,1,:], histories)\n", + "# visualization.plot()" + ], + "metadata": { + "id": "Z90jyPsp1ykH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualizing Results**" + ], + "metadata": { + "id": "iuzgw9uA3idf" + } + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_idx=0)" + ], + "metadata": { + "id": "LeFA_RVg12IS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetuning**" + ], + "metadata": { + "id": "DE5-HDC53n4_" + } + }, + { + "cell_type": "code", + "source": [ + "ch.finetune(train_dataset)" + ], + "metadata": { + "id": "XF2bgAXi125y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading the Sequence to Sequence Model**" + ], + "metadata": { + "id": "Szu7i0UU31IA" + } + }, + { + "cell_type": "code", + "source": [ + "latest_run_dir = ch.get_latest_run_dir()\n", + "model_dir = os.path.join(latest_run_dir, \"checkpoint-final\")\n", + "model_type = \"seq2seq\"\n", + "model = ch.load_model(model_dir, model_type)" + ], + "metadata": { + "id": "5IB_u3T0164d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluating Results on seq2seq Model**" + ], + "metadata": { + "id": "Wgrg7Yxj4l7g" + } + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"])\n", + "print(eval_results)" + ], + "metadata": { + "id": "4zOcRXrf18hZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualizing Results on seq2seq Model**" + ], + "metadata": { + "id": "frcLPZ-W4AOr" + } + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_idx=0)" + ], + "metadata": { + "id": "5Whw0DjV1-Tj" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 08314bc91732313c5935d022b4f6f4a55f2962ce Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 01:46:07 -0500 Subject: [PATCH 02/17] Moving chronos.ipynb to example/colab directory --- chronos.ipynb => example/colab/chronos.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename chronos.ipynb => example/colab/chronos.ipynb (99%) diff --git a/chronos.ipynb b/example/colab/chronos.ipynb similarity index 99% rename from chronos.ipynb rename to example/colab/chronos.ipynb index 4301ea5..1377100 100644 --- a/chronos.ipynb +++ b/example/colab/chronos.ipynb @@ -255,4 +255,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From 4ab40e3a2b8d69dde2d9761cf9d81f73b4ea03e5 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:00:42 -0500 Subject: [PATCH 03/17] Delete example/colab/chronos.ipynb --- example/colab/chronos.ipynb | 258 ------------------------------------ 1 file changed, 258 deletions(-) delete mode 100644 example/colab/chronos.ipynb diff --git a/example/colab/chronos.ipynb b/example/colab/chronos.ipynb deleted file mode 100644 index 1377100..0000000 --- a/example/colab/chronos.ipynb +++ /dev/null @@ -1,258 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "authorship_tag": "ABX9TyO3JkXtZ8GmkdSpgF01FRYS", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "source": [ - "# **Install Dependencies**" - ], - "metadata": { - "id": "tpaQlDVlxetV" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "259--YZOr1Lo" - }, - "outputs": [], - "source": [ - "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" - ] - }, - { - "cell_type": "code", - "source": [ - "!pip install --upgrade -U numpy --force" - ], - "metadata": { - "id": "VGf7Oy_JxpYO" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Importing Requirements**" - ], - "metadata": { - "id": "0cbiLmxU2Lw-" - } - }, - { - "cell_type": "code", - "source": [ - "from tsfmproject.model import ChronosModel\n", - "from tsfmproject.dataset import ChronosDataset\n", - "from tsfmproject.visualization import ForecastVisualization" - ], - "metadata": { - "id": "APfhHKAU1qM3" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Training Dataset**" - ], - "metadata": { - "id": "Gt2KUswn2UtL" - } - }, - { - "cell_type": "code", - "source": [ - "train_dataset = ChronosDataset(name=\"ett\", mode=\"train\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64)\n", - "test_dataset = ChronosDataset(name=\"ett\", mode=\"test\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64)\n", - "\n", - "print(len(test_dataset.dataset))\n", - "# print(test_dataset.dataset.shape)" - ], - "metadata": { - "id": "MyUXzOME1sgo" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Loading the Chronos Model**" - ], - "metadata": { - "id": "OgcxscvD3Ee8" - } - }, - { - "cell_type": "code", - "source": [ - "repo = \"amazon/chronos-t5-small\"\n", - "ch = ChronosModel(config=None, repo=repo)\n", - "ch.load_model()" - ], - "metadata": { - "id": "K9Zmx3li1u5k" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Evaluated Results**" - ], - "metadata": { - "id": "2I0VE7Uu3Jqt" - } - }, - { - "cell_type": "code", - "source": [ - "eval_results, trues, preds, histories = ch.evaluate(test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"])\n", - "print(eval_results)\n", - "# visualization = ForecastVisualization(trues, preds[:,:,1,:], histories)\n", - "# visualization.plot()" - ], - "metadata": { - "id": "Z90jyPsp1ykH" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Visualizing Results**" - ], - "metadata": { - "id": "iuzgw9uA3idf" - } - }, - { - "cell_type": "code", - "source": [ - "visualization = ForecastVisualization(trues, preds, histories)\n", - "visualization.plot(channel_idx=0, time_idx=0)" - ], - "metadata": { - "id": "LeFA_RVg12IS" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Finetuning**" - ], - "metadata": { - "id": "DE5-HDC53n4_" - } - }, - { - "cell_type": "code", - "source": [ - "ch.finetune(train_dataset)" - ], - "metadata": { - "id": "XF2bgAXi125y" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Loading the Sequence to Sequence Model**" - ], - "metadata": { - "id": "Szu7i0UU31IA" - } - }, - { - "cell_type": "code", - "source": [ - "latest_run_dir = ch.get_latest_run_dir()\n", - "model_dir = os.path.join(latest_run_dir, \"checkpoint-final\")\n", - "model_type = \"seq2seq\"\n", - "model = ch.load_model(model_dir, model_type)" - ], - "metadata": { - "id": "5IB_u3T0164d" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Evaluating Results on seq2seq Model**" - ], - "metadata": { - "id": "Wgrg7Yxj4l7g" - } - }, - { - "cell_type": "code", - "source": [ - "eval_results, trues, preds, histories = ch.evaluate(test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"])\n", - "print(eval_results)" - ], - "metadata": { - "id": "4zOcRXrf18hZ" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "# **Visualizing Results on seq2seq Model**" - ], - "metadata": { - "id": "frcLPZ-W4AOr" - } - }, - { - "cell_type": "code", - "source": [ - "visualization = ForecastVisualization(trues, preds, histories)\n", - "visualization.plot(channel_idx=0, time_idx=0)" - ], - "metadata": { - "id": "5Whw0DjV1-Tj" - }, - "execution_count": null, - "outputs": [] - } - ] -} From 378d206cd9b55a414c1ef4ad041e25d7e25e30f7 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:11:17 -0500 Subject: [PATCH 04/17] Created using Colab --- chronos.ipynb | 213 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 chronos.ipynb diff --git a/chronos.ipynb b/chronos.ipynb new file mode 100644 index 0000000..f005451 --- /dev/null +++ b/chronos.ipynb @@ -0,0 +1,213 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOHc9Zna3DRB0e4ryzjGrBk", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "259--YZOr1Lo" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "import numpy as np\n", + "\n", + "src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n", + "if src_path not in sys.path:\n", + " sys.path.insert(0, src_path)\n", + "\n", + "from samay.model import ChronosModel\n", + "from samay.dataset import ChronosDataset\n", + "# from tsfmproject.utils import load_args\n", + "\n", + "# arg_path = \"../config/timesfm.json\"\n", + "# args = load_args(arg_path)\n", + "repo = \"amazon/chronos-t5-small\"\n", + "chronos_model = ChronosModel(repo=repo)" + ], + "metadata": { + "id": "APfhHKAU1qM3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='train', batch_size=8)\n", + "val_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='test', batch_size=8)" + ], + "metadata": { + "id": "MyUXzOME1sgo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualize the zero-shot forecasting**" + ], + "metadata": { + "id": "OgcxscvD3Ee8" + } + }, + { + "cell_type": "code", + "source": [ + "chronos_model.plot(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])" + ], + "metadata": { + "id": "K9Zmx3li1u5k" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the zero-shot Chronos Model**" + ], + "metadata": { + "id": "2I0VE7Uu3Jqt" + } + }, + { + "cell_type": "code", + "source": [ + "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n", + "print(metrics)" + ], + "metadata": { + "id": "Z90jyPsp1ykH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Chronos Model on the ETT dataset**" + ], + "metadata": { + "id": "iuzgw9uA3idf" + } + }, + { + "cell_type": "code", + "source": [ + "chronos_model.finetune(train_dataset)" + ], + "metadata": { + "id": "LeFA_RVg12IS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the Finetuned Chronos Model**" + ], + "metadata": { + "id": "DE5-HDC53n4_" + } + }, + { + "cell_type": "code", + "source": [ + "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n", + "print(metrics)" + ], + "metadata": { + "id": "XF2bgAXi125y" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 68624b5ba16b0216996d2c35e59b4c46604227b2 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:12:04 -0500 Subject: [PATCH 05/17] Adding chronos.ipynb to example/colab directory --- chronos.ipynb => example/colab/chronos.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename chronos.ipynb => example/colab/chronos.ipynb (99%) diff --git a/chronos.ipynb b/example/colab/chronos.ipynb similarity index 99% rename from chronos.ipynb rename to example/colab/chronos.ipynb index f005451..b47ff4c 100644 --- a/chronos.ipynb +++ b/example/colab/chronos.ipynb @@ -210,4 +210,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From b4c72cb8df35ac42394cbd429864b65e594fbabe Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:19:06 -0500 Subject: [PATCH 06/17] Created using Colab --- chronos_trial.ipynb | 645 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 645 insertions(+) create mode 100644 chronos_trial.ipynb diff --git a/chronos_trial.ipynb b/chronos_trial.ipynb new file mode 100644 index 0000000..fb9ae68 --- /dev/null +++ b/chronos_trial.ipynb @@ -0,0 +1,645 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyPxvN1ZBJ3PXV+dtaz3/lQX", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "src_path = os.path.abspath(os.path.join(\"src\"))\n", + "if src_path not in sys.path:\n", + " sys.path.insert(0, src_path)\n", + "\n", + "print(sys.path)" + ], + "metadata": { + "id": "9KSGbm9r9wVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from samay.dataset import ChronosDataset\n", + "from samay.model import ChronosModel\n", + "from samay.visualization import ForecastVisualization" + ], + "metadata": { + "id": "0-Z4X0kdNPR8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = ChronosDataset(\n", + " name=\"ett\",\n", + " mode=\"train\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " datetime_col=\"date\",\n", + " freq=\"h\",\n", + " context_len=128,\n", + " horizon_len=64,\n", + ")\n", + "test_dataset = ChronosDataset(\n", + " name=\"ett\",\n", + " mode=\"test\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " datetime_col=\"date\",\n", + " freq=\"h\",\n", + " context_len=128,\n", + " horizon_len=64,\n", + ")\n", + "# train_dataset = ChronosDataset(name=\"ett\", mode=\"train\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n", + "# test_dataset = ChronosDataset(name=\"ett\", mode=\"test\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n", + "print(len(test_dataset.dataset))\n", + "# print(test_dataset.dataset.shape)" + ], + "metadata": { + "id": "UoMPUycM-uX4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading the Chronos Model**" + ], + "metadata": { + "id": "OgcxscvD3Ee8" + } + }, + { + "cell_type": "code", + "source": [ + "repo = \"amazon/chronos-t5-small\"\n", + "ch = ChronosModel(config=None, repo=repo)\n", + "ch.load_model()" + ], + "metadata": { + "id": "yfcJVg7d-xE0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(ch.model.model.device)" + ], + "metadata": { + "id": "lt5Scq3e-zTu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(\n", + " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n", + ")\n", + "print(eval_results)\n", + "# visualization = ForecastVisualization(trues, preds[:,:,1,:], histories)\n", + "# visualization.plot()" + ], + "metadata": { + "id": "Jsmfu_hS-1S-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_idx=0)" + ], + "metadata": { + "id": "Dk_Rkgvd-373" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(trues.shape)\n", + "print(preds.shape)\n", + "print(histories.shape)" + ], + "metadata": { + "id": "k-Y5T3qk-6RV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ch.finetune(train_dataset)" + ], + "metadata": { + "id": "TWElWIOD-75d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "latest_run_dir = ch.get_latest_run_dir()\n", + "model_dir = os.path.join(latest_run_dir, \"checkpoint-final\")\n", + "model_type = \"seq2seq\"\n", + "model = ch.load_model(model_dir, model_type)" + ], + "metadata": { + "id": "X4FyWD2q-9_-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(\n", + " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n", + ")\n", + "print(eval_results)" + ], + "metadata": { + "id": "UjsDeYg9-_0T" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_index=0)" + ], + "metadata": { + "id": "SumTaMAd_B2d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "data = test_dataset.dataset\n", + "data = np.array(data).transpose()\n", + "\n", + "print(data.shape)" + ], + "metadata": { + "id": "Xvrrb_eL_DWh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input = [torch.tensor(ts[:1000]) for i, ts in enumerate(data)]\n", + "print(input[0].shape)\n", + "predictions = ch.model.predict(context=input, prediction_length=64, num_samples=10)" + ], + "metadata": { + "id": "ex5hbQEF_FBD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from gluonts.dataset.pandas import PandasDataset\n", + "from gluonts.dataset.split import split" + ], + "metadata": { + "id": "cfJ6LK1ONpMf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "dataset = test_dataset.dataset\n", + "dataset = PandasDataset(dict(dataset))\n", + "train, test_template = split(dataset, offset=-128 + 20 * 64)\n", + "print(test_template)\n", + "print(len(dataset))\n", + "test_data = test_template.generate_instances(\n", + " prediction_length=64, windows=20, distance=64\n", + ")" + ], + "metadata": { + "id": "nYGxZbHnNrA-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(test_data)" + ], + "metadata": { + "id": "5TAndQWYNswE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input_it = iter(test_data.input)\n", + "label_it = iter(test_data.label)\n", + "# inp = next(input_it)\n", + "# label = next(label_it)\n", + "print(inp)\n", + "print(label[\"target\"].shape)" + ], + "metadata": { + "id": "b_qowpKnNup2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for inp, label in zip(input_it, label_it):\n", + " print(inp[\"item_id\"], label[\"item_id\"], label[\"target\"].shape)" + ], + "metadata": { + "id": "uruJM9vqNxs-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(test_dataset.dataset)" + ], + "metadata": { + "id": "PSDdqD-kNz7w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "\n", + "class ChronosWindowDataset(Dataset):\n", + " \"\"\"\n", + " A PyTorch Dataset for sliding window extraction from time series data.\n", + " \"\"\"\n", + "\n", + " def __init__(self, data, context_len, horizon_len, stride=-1):\n", + " \"\"\"\n", + " Initialize the dataset with sliding window logic.\n", + "\n", + " Args:\n", + " data (pd.DataFrame): The input time series data.\n", + " context_len (int): Length of the context window.\n", + " horizon_len (int): Length of the forecast horizon.\n", + " stride (int): Step size for sliding the window.\n", + " \"\"\"\n", + " self.data = data\n", + " self.context_len = context_len\n", + " self.horizon_len = horizon_len\n", + " self.total_len = context_len + horizon_len\n", + " self.stride = stride\n", + "\n", + " if self.stride == -1:\n", + " self.stride = self.horizon_len\n", + "\n", + " # Generate start indices for sliding windows\n", + " self.indices = [\n", + " start for start in range(0, len(data) - self.total_len + 1, self.stride)\n", + " ]\n", + "\n", + " def __len__(self):\n", + " return len(self.indices)\n", + "\n", + " def __getitem__(self, idx):\n", + " start = self.indices[idx]\n", + " window = self.data.iloc[start : start + self.total_len]\n", + "\n", + " # Extract context and actuals, and convert to Torch tensors\n", + " context = torch.tensor(\n", + " window.iloc[: self.context_len].to_numpy().transpose(), dtype=torch.float32\n", + " )\n", + " actual = torch.tensor(\n", + " window.iloc[self.context_len :].to_numpy().transpose(), dtype=torch.float32\n", + " )\n", + "\n", + " # # Return the input as a list of tensors (one for each column)\n", + " # input_list = [context[i] for i in range(context.shape[0])]\n", + "\n", + " return context, actual" + ], + "metadata": { + "id": "0_HeGsp7N0p2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "test_data = test_dataset.dataset\n", + "print(test_data)" + ], + "metadata": { + "id": "d5gwi5DWN5OM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "# window_dataset = ChronosWindowDataset(data=test_data, context_len=128, horizon_len=64)\n", + "data_loader = DataLoader(test_data, batch_size=8, shuffle=False)" + ], + "metadata": { + "id": "gysL-YTAN61s" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "with torch.no_grad():\n", + " for i, (context, actual) in enumerate(data_loader):\n", + " print(context.shape)\n", + " print(actual.shape)" + ], + "metadata": { + "id": "vOnq2B7cN8ou" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input, actual = next(iter(data_loader))\n", + "input = input.squeeze()\n", + "actual = actual.squeeze()\n", + "print(input.shape)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "TRQ4SDaVN-Xy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input_stack = input.reshape(-1, 128)\n", + "print(input_stack.shape)" + ], + "metadata": { + "id": "-1KkT_1GOAL8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(-1, 64)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "Df5hiY9MOAlQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "predictions = ch.model.predict(\n", + " context=input_stack, prediction_length=64, num_samples=20\n", + ")" + ], + "metadata": { + "id": "9QeD4CI0ODBM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(predictions.shape)" + ], + "metadata": { + "id": "Hb2d_xRYOGvA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(predictions.shape)\n", + "pred_median = np.median(predictions, axis=1)\n", + "print(pred_median.shape)\n", + "pred_quantiles = np.quantile(\n", + " predictions, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], axis=1\n", + ")\n", + "print(pred_quantiles.shape)" + ], + "metadata": { + "id": "xaGEglU5OJB1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(-1, 64)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "jwKg5EBtOKdF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "mse1 = np.mean((actual.numpy() - pred_quantiles[4]) ** 2)\n", + "print(mse1)\n", + "mse1 = np.mean((actual.numpy() - pred_median) ** 2)\n", + "print(mse1)" + ], + "metadata": { + "id": "HsFP4XrkOMAB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(ch.model.model.device)" + ], + "metadata": { + "id": "o2O76zOzON49" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(8, 7, 64)\n", + "pred_median = pred_median.reshape(8, 7, 64)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "6vgEzR4aOPZD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "mse1 = np.mean((actual - pred_median) ** 2)\n", + "print(mse1)" + ], + "metadata": { + "id": "-2_f2c-aOP8w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "data = test_dataset.dataset\n", + "print(len(data.iloc[:10]))" + ], + "metadata": { + "id": "p2BmmGwPORRa" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From f1775d4c6678fbec29f878aa12ea362a59303868 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:19:51 -0500 Subject: [PATCH 07/17] Moving chronos_trial.ipynb to the example/colab directory --- chronos_trial.ipynb => example/colab/chronos_trial.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename chronos_trial.ipynb => example/colab/chronos_trial.ipynb (99%) diff --git a/chronos_trial.ipynb b/example/colab/chronos_trial.ipynb similarity index 99% rename from chronos_trial.ipynb rename to example/colab/chronos_trial.ipynb index fb9ae68..106a2c6 100644 --- a/chronos_trial.ipynb +++ b/example/colab/chronos_trial.ipynb @@ -642,4 +642,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From 6432346bf4b689c9f1c2bf827d9b4cac30c39eb4 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:26:01 -0500 Subject: [PATCH 08/17] Created using Colab --- moirai.ipynb | 187 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 moirai.ipynb diff --git a/moirai.ipynb b/moirai.ipynb new file mode 100644 index 0000000..ba6941e --- /dev/null +++ b/moirai.ipynb @@ -0,0 +1,187 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNMrNyISIi+DxlbMTxeYGbS", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from tsfmproject.dataset import MoiraiDataset\n", + "from tsfmproject.model import MoiraiTSModel" + ], + "metadata": { + "id": "MAojvVF6PEGP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MoiraiDataset(name=\"ett\", mode=\"train\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)\n", + "test_dataset = MoiraiDataset(name=\"ett\", mode=\"test\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)" + ], + "metadata": { + "id": "gX48JpxwPJwN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "config = {\n", + " \"context_len\": 128,\n", + " \"horizon_len\": 64,\n", + "}\n", + "model_type = \"moirai-moe\"\n", + "model_size = \"small\"\n", + "moirai_model = MoiraiTSModel(model_type=model_type, model_size=model_size, config=config)" + ], + "metadata": { + "id": "BofOhe74PO-4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(moirai_model.model.device)" + ], + "metadata": { + "id": "2ff_VDTnPS0o" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = moirai_model.evaluate(test_dataset, metrics=[\"MSE\", \"MASE\"])" + ], + "metadata": { + "id": "rPNns3q5PTNe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(eval_results)" + ], + "metadata": { + "id": "9ZBsQ8gzPV6T" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(len(histories)):\n", + " print(i, histories[i].shape)" + ], + "metadata": { + "id": "9DpENZS7PXen" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(len(test_dataset.data))" + ], + "metadata": { + "id": "fLoUWiqOPZAR" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 4677856989d9de4512a77d74293325ff3f571514 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:26:32 -0500 Subject: [PATCH 09/17] Move moirai.ipynb to example/colab/ directory --- moirai.ipynb => example/colab/moirai.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename moirai.ipynb => example/colab/moirai.ipynb (99%) diff --git a/moirai.ipynb b/example/colab/moirai.ipynb similarity index 99% rename from moirai.ipynb rename to example/colab/moirai.ipynb index ba6941e..57dde82 100644 --- a/moirai.ipynb +++ b/example/colab/moirai.ipynb @@ -184,4 +184,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From fd6e7f8487bc949d188a2afaecf09e4a5b93ac31 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:32:13 -0500 Subject: [PATCH 10/17] Created using Colab --- moment_anomaly_detection.ipynb | 233 +++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 moment_anomaly_detection.ipynb diff --git a/moment_anomaly_detection.ipynb b/moment_anomaly_detection.ipynb new file mode 100644 index 0000000..e73f93c --- /dev/null +++ b/moment_anomaly_detection.ipynb @@ -0,0 +1,233 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyPLZ1a9i7oPxMNhCvZBpzN+", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"reconstruction\",\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "aQSjnUjKQqlE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Zero-Shot Anomaly Detection Using Moment Model**" + ], + "metadata": { + "id": "CHgBRG2AQ4vD" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n", + " mode=\"train\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n", + "test_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n", + " mode=\"test\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n", + "# print(len(train_dataset))\n", + "# print(len(test_dataset))\n", + "# trues, preds, labels = mmt.evaluate(test_dataset, task_name='detection')\n", + "mmt.plot(test_dataset, task_name='detection',)" + ], + "metadata": { + "id": "-GnUpadKQ2pz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of Zero-Shot Anomaly Detection**" + ], + "metadata": { + "id": "rmBuetJARAR_" + } + }, + { + "cell_type": "code", + "source": [ + "# import numpy as np\n", + "# import matplotlib.pyplot as plt\n", + "# from tsfmproject.models.moment.momentfm.utils.anomaly_detection_metrics import adjbestf1\n", + "\n", + "# # We will use the Mean Squared Error (MSE) between the observed values and MOMENT's predictions as the anomaly score\n", + "# anomaly_scores = (trues - preds)**2\n", + "\n", + "# print(f\"Zero-shot Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\")\n", + "\n", + "# anomaly_start = 74158\n", + "# anomaly_end = 74984\n", + "# start = anomaly_start-512\n", + "# end = anomaly_end+512\n", + "\n", + "# plt.plot(trues[start:end], label=\"Observed\", c='darkblue')\n", + "# plt.plot(preds[start:end], label=\"Predicted\", c='red')\n", + "# plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c='black')\n", + "# plt.legend(fontsize=16)\n", + "# plt.show()" + ], + "metadata": { + "id": "SDspECriRDvN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **FInetune the Moment model for anomaly detection**" + ], + "metadata": { + "id": "4woLdKIERIUf" + } + }, + { + "cell_type": "code", + "source": [ + "finetuned_model = mmt.finetune(\n", + " train_dataset, task_name=\"detection\", mask_ratio=0.1, epoch=5\n", + ")" + ], + "metadata": { + "id": "Dw7OexJ_RGGH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the finetuned model and Visualization**" + ], + "metadata": { + "id": "HTva4_cARNWY" + } + }, + { + "cell_type": "code", + "source": [ + "trues, preds, labels = mmt.evaluate(test_dataset, task_name=\"detection\")\n", + "\n", + "anomaly_scores = (trues - preds) ** 2\n", + "\n", + "print(\n", + " f\"Zero-shot Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\"\n", + ")\n", + "\n", + "anomaly_start = 74158\n", + "anomaly_end = 74984\n", + "start = anomaly_start - 512\n", + "end = anomaly_end + 512\n", + "\n", + "plt.plot(trues[start:end], label=\"Observed\", c=\"darkblue\")\n", + "plt.plot(preds[start:end], label=\"Predicted\", c=\"red\")\n", + "plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c=\"black\")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ], + "metadata": { + "id": "Mkk6GyTyRQX8" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 79a0e9f41f2a676f02a0d88a75bc5ab99c403ac1 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:33:13 -0500 Subject: [PATCH 11/17] Moving moment_anomaly_detection.ipynb to example/colab/ directory --- .../colab/moment_anomaly_detection.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename moment_anomaly_detection.ipynb => example/colab/moment_anomaly_detection.ipynb (99%) diff --git a/moment_anomaly_detection.ipynb b/example/colab/moment_anomaly_detection.ipynb similarity index 99% rename from moment_anomaly_detection.ipynb rename to example/colab/moment_anomaly_detection.ipynb index e73f93c..2ed717b 100644 --- a/moment_anomaly_detection.ipynb +++ b/example/colab/moment_anomaly_detection.ipynb @@ -230,4 +230,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From cbfb988c4c1db95d0debae1a24be89704e1b0972 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:38:45 -0500 Subject: [PATCH 12/17] Created using Colab --- moment_classification.ipynb | 236 ++++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 moment_classification.ipynb diff --git a/moment_classification.ipynb b/moment_classification.ipynb new file mode 100644 index 0000000..8c330a0 --- /dev/null +++ b/moment_classification.ipynb @@ -0,0 +1,236 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNoSDnZw0e+6sXWq4gOcrTv", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.models.moment.momentfm.models.statistical_classifiers import fit_svm\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\"task_name\": \"classification\", \"n_channels\": 1, \"num_class\": 5}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "9KSGbm9r9wVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "n9Nl5kajSVlc" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Use a SVM classifier to evaluate zero-shot embeddings of Moment model**" + ], + "metadata": { + "id": "4k7hDhXGScrx" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(\n", + " name=\"ecg5000\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TRAIN.csv\",\n", + " batchsize=64,\n", + " mode=\"train\",\n", + " task_name=\"classification\",\n", + ")\n", + "test_dataset = MomentDataset(\n", + " name=\"ecg5000\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TEST.csv\",\n", + " batchsize=64,\n", + " mode=\"test\",\n", + " task_name=\"classification\",\n", + ")\n", + "\n", + "train_accuracy, train_embeddings, train_labels = mmt.evaluate(\n", + " train_dataset, task_name=\"classification\"\n", + ")\n", + "test_accuracy, test_embeddings, test_labels = mmt.evaluate(\n", + " test_dataset, task_name=\"classification\"\n", + ")\n", + "print(train_embeddings.shape, train_labels.shape)\n", + "\n", + "clf = fit_svm(features=train_embeddings, y=train_labels)\n", + "\n", + "y_pred_train = clf.predict(train_embeddings)\n", + "y_pred_test = clf.predict(test_embeddings)\n", + "train_accuracy = clf.score(train_embeddings, train_labels)\n", + "test_accuracy = clf.score(test_embeddings, test_labels)\n", + "\n", + "print(f\"Train accuracy: {train_accuracy:.2f}\")\n", + "print(f\"Test accuracy: {test_accuracy:.2f}\")" + ], + "metadata": { + "id": "1BcqPINMSf-M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualize the embeddings**" + ], + "metadata": { + "id": "vFe6nzXMSjaT" + } + }, + { + "cell_type": "code", + "source": [ + "from sklearn.decomposition import PCA\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "test_embeddings_manifold = PCA(n_components=2).fit_transform(test_embeddings)\n", + "\n", + "plt.title(\"ECG5000 Test Embeddings\", fontsize=20)\n", + "plt.scatter(\n", + " test_embeddings_manifold[:, 0],\n", + " test_embeddings_manifold[:, 1],\n", + " c=test_labels.squeeze(),\n", + ")\n", + "plt.xticks(fontsize=16)\n", + "plt.yticks(fontsize=16)\n", + "plt.show()" + ], + "metadata": { + "id": "eTir0t5MSlt9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Moment model for classification**" + ], + "metadata": { + "id": "pHu3gQT_SoBO" + } + }, + { + "cell_type": "code", + "source": [ + "finetuned_model = mmt.finetune(\n", + " train_dataset, task_name=\"classification\", epoch=10, lr=0.1\n", + ")" + ], + "metadata": { + "id": "6arKH_p4Sq-Z" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the finetuned model**" + ], + "metadata": { + "id": "3cJ8l9v6SuuB" + } + }, + { + "cell_type": "code", + "source": [ + "accuracy, embeddings, lebels = mmt.evaluate(test_dataset, task_name=\"classification\")\n", + "print(accuracy)" + ], + "metadata": { + "id": "Dk5iOBAHSsqF" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From f7cbf19fd2254bcdf4b449e301640811af5c323a Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:39:11 -0500 Subject: [PATCH 13/17] Move moment_classification.ipynb to example/colab/ directory --- .../colab/moment_classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename moment_classification.ipynb => example/colab/moment_classification.ipynb (99%) diff --git a/moment_classification.ipynb b/example/colab/moment_classification.ipynb similarity index 99% rename from moment_classification.ipynb rename to example/colab/moment_classification.ipynb index 8c330a0..88fa7e5 100644 --- a/moment_classification.ipynb +++ b/example/colab/moment_classification.ipynb @@ -233,4 +233,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From 8859cc2f37995aa94001f5ecf7f0261bc9586d8e Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:46:48 -0500 Subject: [PATCH 14/17] Created using Colab --- moment_forecasting.ipynb | 218 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 moment_forecasting.ipynb diff --git a/moment_forecasting.ipynb b/moment_forecasting.ipynb new file mode 100644 index 0000000..d26713f --- /dev/null +++ b/moment_forecasting.ipynb @@ -0,0 +1,218 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyMOjbk5TzySlGkxItm+Z25O", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"forecasting\",\n", + " \"forecast_horizon\": 192,\n", + " \"head_dropout\": 0.1,\n", + " \"weight_decay\": 0,\n", + " \"freeze_encoder\": True, # Freeze the patch embedding layer\n", + " \"freeze_embedder\": True, # Freeze the transformer encoder\n", + " \"freeze_head\": False, # The linear forecasting head must be trained\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "MzAU2V_tTXSB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Moment Model on the ETT dataset**" + ], + "metadata": { + "id": "49ihlnzET33r" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='train', horizon_len=192)\n", + "# dtl = train_dataset.get_data_loader()\n", + "\n", + "val_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='test', horizon_len=192)\n", + "# path = '../src/tsfmproject/models/moment/data/ETTh1.csv'\n", + "\n", + "# dataset = MomentDataset(name=\"ett\", datetime_col='date', path=path,\n", + "# mode='train', horizon=192)\n", + "\n", + "finetuned_model = mmt.finetune(train_dataset, task_name=\"forecasting\")\n", + "mmt.evaluate(val_dataset, task_name=\"forecasting\")" + ], + "metadata": { + "id": "NlPWSR-VT7Zm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Test the Finetuned Model**" + ], + "metadata": { + "id": "jVTskCB-UKPp" + } + }, + { + "cell_type": "code", + "source": [ + "# avg_loss, trues, preds, histories = mmt.evaluate(val_dataset, task_name='forecasting')\n", + "# print(\"Validation loss:\", avg_loss)\n", + "mmt.plot(val_dataset, task_name='forecasting')" + ], + "metadata": { + "id": "keWlMzyCUQJ3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of the evaluation**" + ], + "metadata": { + "id": "Lkrbk9zvUS04" + } + }, + { + "cell_type": "code", + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# # Pick a random channel and time index\n", + "# trues = np.array(trues)\n", + "# preds = np.array(preds)\n", + "# histories = np.array(histories)\n", + "# channel_idx = np.random.randint(0, 7)\n", + "# time_index = np.random.randint(0, trues.shape[0])\n", + "\n", + "# history = histories[time_index, channel_idx, :]\n", + "# true = trues[time_index, channel_idx, :]\n", + "# pred = preds[time_index, channel_idx, :]\n", + "\n", + "# plt.figure(figsize=(12, 4))\n", + "\n", + "# # Plotting the first time series from history\n", + "# plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')\n", + "\n", + "# # Plotting ground truth and prediction\n", + "# num_forecasts = len(true)\n", + "\n", + "# offset = len(history)\n", + "# plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (192 timesteps)', color='darkblue', linestyle='--', alpha=0.5)\n", + "# plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (192 timesteps)', color='red', linestyle='--')\n", + "\n", + "# plt.title(f\"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})\", fontsize=18)\n", + "# plt.xlabel('Time', fontsize=14)\n", + "# plt.ylabel('Value', fontsize=14)\n", + "# plt.legend(fontsize=14)\n", + "# plt.show()" + ], + "metadata": { + "id": "o7F-UCfnUUtF" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From c0a2c3c7377fa8a208a1608a7c733e8963fceb23 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:47:13 -0500 Subject: [PATCH 15/17] Move to moment_forecasting.ipynb to example/colab/ directory --- .../colab/moment_forecasting.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename moment_forecasting.ipynb => example/colab/moment_forecasting.ipynb (99%) diff --git a/moment_forecasting.ipynb b/example/colab/moment_forecasting.ipynb similarity index 99% rename from moment_forecasting.ipynb rename to example/colab/moment_forecasting.ipynb index d26713f..68f4fbb 100644 --- a/moment_forecasting.ipynb +++ b/example/colab/moment_forecasting.ipynb @@ -215,4 +215,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From 8ffc41449e78cd6e0d1baac15ef05a8d182ff1f6 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:50:33 -0500 Subject: [PATCH 16/17] Created using Colab --- moment_imputation.ipynb | 240 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 moment_imputation.ipynb diff --git a/moment_imputation.ipynb b/moment_imputation.ipynb new file mode 100644 index 0000000..1194c11 --- /dev/null +++ b/moment_imputation.ipynb @@ -0,0 +1,240 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOEZ3bu8VQnUuKZKhRj8BZu", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"reconstruction\",\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "EKfkmPeNVEhs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Zero-Shot Imputation Using Moment Model**" + ], + "metadata": { + "id": "G2i1_KoTVHXc" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(\n", + " name=\"ett\",\n", + " datetime_col=\"date\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " mode=\"train\",\n", + " task_name=\"imputation\",\n", + ")\n", + "test_dataset = MomentDataset(\n", + " name=\"ett\",\n", + " datetime_col=\"date\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " mode=\"test\",\n", + " task_name=\"imputation\",\n", + ")\n", + "# print(len(train_dataset))\n", + "# print(len(test_dataset))\n", + "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "WhjXYZLsVJr1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of Zero-Shot Imputation**" + ], + "metadata": { + "id": "1UEUE0QGVSdA" + } + }, + { + "cell_type": "code", + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# print(trues.shape, preds.shape, masks.shape)\n", + "\n", + "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n", + "# mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n", + "# print(f'MSE: {mse}, MAE: {mae}')\n", + "\n", + "# idx = np.random.randint(trues.shape[0])\n", + "# channel_idx = np.random.randint(trues.shape[1])\n", + "\n", + "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n", + "# axs[0].set_title(f\"Channel={channel_idx}\")\n", + "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n", + "# axs[0].plot(preds[idx, channel_idx, :].squeeze(), label='Predictions', c='red')\n", + "# axs[0].legend(fontsize=16)\n", + "\n", + "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n", + "# plt.show()\n", + "mmt.plot(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "ldr2IIb1VWGj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune the Moment Model for Imputation**" + ], + "metadata": { + "id": "kShP864XVZoB" + } + }, + { + "cell_type": "code", + "source": [ + "fintuned_model = mmt.finetune(train_dataset, task_name=\"imputation\")" + ], + "metadata": { + "id": "a4aYBmjDVcSx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the finetuned model and visualization**" + ], + "metadata": { + "id": "Hz1cVDpVVd0S" + } + }, + { + "cell_type": "code", + "source": [ + "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')\n", + "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n", + "# mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n", + "# print(f'MSE: {mse}, MAE: {mae}')\n", + "\n", + "# idx = np.random.randint(trues.shape[0])\n", + "# channel_idx = np.random.randint(trues.shape[1])\n", + "\n", + "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n", + "# axs[0].set_title(f\"Channel={channel_idx}\")\n", + "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n", + "# axs[0].plot(preds[idx, channel_idx, :].squeeze(), label='Predictions', c='red')\n", + "# axs[0].legend(fontsize=16)\n", + "\n", + "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n", + "# plt.show()\n", + "mmt.plot(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "Yi0FjTCVVflk" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From 1fd081428b6e5350b473248e3d1d1616e3853a77 Mon Sep 17 00:00:00 2001 From: Showmick Das Date: Thu, 20 Feb 2025 02:50:59 -0500 Subject: [PATCH 17/17] Move moment_imputation.ipynb to example/colab/ directory --- .../colab/moment_imputation.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename moment_imputation.ipynb => example/colab/moment_imputation.ipynb (99%) diff --git a/moment_imputation.ipynb b/example/colab/moment_imputation.ipynb similarity index 99% rename from moment_imputation.ipynb rename to example/colab/moment_imputation.ipynb index 1194c11..c7a247b 100644 --- a/moment_imputation.ipynb +++ b/example/colab/moment_imputation.ipynb @@ -237,4 +237,4 @@ "outputs": [] } ] -} \ No newline at end of file +}