diff --git a/example/colab/chronos.ipynb b/example/colab/chronos.ipynb new file mode 100644 index 0000000..b47ff4c --- /dev/null +++ b/example/colab/chronos.ipynb @@ -0,0 +1,213 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOHc9Zna3DRB0e4ryzjGrBk", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "259--YZOr1Lo" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "import numpy as np\n", + "\n", + "src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n", + "if src_path not in sys.path:\n", + " sys.path.insert(0, src_path)\n", + "\n", + "from samay.model import ChronosModel\n", + "from samay.dataset import ChronosDataset\n", + "# from tsfmproject.utils import load_args\n", + "\n", + "# arg_path = \"../config/timesfm.json\"\n", + "# args = load_args(arg_path)\n", + "repo = \"amazon/chronos-t5-small\"\n", + "chronos_model = ChronosModel(repo=repo)" + ], + "metadata": { + "id": "APfhHKAU1qM3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + 
"source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='train', batch_size=8)\n", + "val_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='test', batch_size=8)" + ], + "metadata": { + "id": "MyUXzOME1sgo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualize the zero-shot forecasting**" + ], + "metadata": { + "id": "OgcxscvD3Ee8" + } + }, + { + "cell_type": "code", + "source": [ + "chronos_model.plot(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])" + ], + "metadata": { + "id": "K9Zmx3li1u5k" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the zero-shot Chronos Model**" + ], + "metadata": { + "id": "2I0VE7Uu3Jqt" + } + }, + { + "cell_type": "code", + "source": [ + "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n", + "print(metrics)" + ], + "metadata": { + "id": "Z90jyPsp1ykH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Chronos Model on the ETT dataset**" + ], + "metadata": { + "id": "iuzgw9uA3idf" + } + }, + { + "cell_type": "code", + "source": [ + "chronos_model.finetune(train_dataset)" + ], + "metadata": { + "id": "LeFA_RVg12IS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the Finetuned Chronos Model**" + ], + "metadata": { + 
"id": "DE5-HDC53n4_" + } + }, + { + "cell_type": "code", + "source": [ + "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n", + "print(metrics)" + ], + "metadata": { + "id": "XF2bgAXi125y" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/chronos_trial.ipynb b/example/colab/chronos_trial.ipynb new file mode 100644 index 0000000..106a2c6 --- /dev/null +++ b/example/colab/chronos_trial.ipynb @@ -0,0 +1,645 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyPxvN1ZBJ3PXV+dtaz3/lQX", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import sys\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "src_path = os.path.abspath(os.path.join(\"src\"))\n", + "if src_path not in sys.path:\n", + " sys.path.insert(0, src_path)\n", + "\n", + "print(sys.path)" + ], + "metadata": { + "id": "9KSGbm9r9wVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", 
+ "source": [ + "from samay.dataset import ChronosDataset\n", + "from samay.model import ChronosModel\n", + "from samay.visualization import ForecastVisualization" + ], + "metadata": { + "id": "0-Z4X0kdNPR8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = ChronosDataset(\n", + " name=\"ett\",\n", + " mode=\"train\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " datetime_col=\"date\",\n", + " freq=\"h\",\n", + " context_len=128,\n", + " horizon_len=64,\n", + ")\n", + "test_dataset = ChronosDataset(\n", + " name=\"ett\",\n", + " mode=\"test\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " datetime_col=\"date\",\n", + " freq=\"h\",\n", + " context_len=128,\n", + " horizon_len=64,\n", + ")\n", + "# train_dataset = ChronosDataset(name=\"ett\", mode=\"train\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n", + "# test_dataset = ChronosDataset(name=\"ett\", mode=\"test\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n", + "print(len(test_dataset.dataset))\n", + "# print(test_dataset.dataset.shape)" + ], + "metadata": { + "id": "UoMPUycM-uX4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading the Chronos Model**" + ], + "metadata": { + "id": "OgcxscvD3Ee8" + } + }, + { + 
"cell_type": "code", + "source": [ + "repo = \"amazon/chronos-t5-small\"\n", + "ch = ChronosModel(config=None, repo=repo)\n", + "ch.load_model()" + ], + "metadata": { + "id": "yfcJVg7d-xE0" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(ch.model.model.device)" + ], + "metadata": { + "id": "lt5Scq3e-zTu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(\n", + " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n", + ")\n", + "print(eval_results)\n", + "# visualization = ForecastVisualization(trues, preds[:,:,1,:], histories)\n", + "# visualization.plot()" + ], + "metadata": { + "id": "Jsmfu_hS-1S-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_idx=0)" + ], + "metadata": { + "id": "Dk_Rkgvd-373" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(trues.shape)\n", + "print(preds.shape)\n", + "print(histories.shape)" + ], + "metadata": { + "id": "k-Y5T3qk-6RV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ch.finetune(train_dataset)" + ], + "metadata": { + "id": "TWElWIOD-75d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "latest_run_dir = ch.get_latest_run_dir()\n", + "model_dir = os.path.join(latest_run_dir, \"checkpoint-final\")\n", + "model_type = \"seq2seq\"\n", + "model = ch.load_model(model_dir, model_type)" + ], + "metadata": { + "id": "X4FyWD2q-9_-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = ch.evaluate(\n", + " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n", + ")\n", + "print(eval_results)" 
+ ], + "metadata": { + "id": "UjsDeYg9-_0T" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "visualization = ForecastVisualization(trues, preds, histories)\n", + "visualization.plot(channel_idx=0, time_index=0)" + ], + "metadata": { + "id": "SumTaMAd_B2d" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "data = test_dataset.dataset\n", + "data = np.array(data).transpose()\n", + "\n", + "print(data.shape)" + ], + "metadata": { + "id": "Xvrrb_eL_DWh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input = [torch.tensor(ts[:1000]) for i, ts in enumerate(data)]\n", + "print(input[0].shape)\n", + "predictions = ch.model.predict(context=input, prediction_length=64, num_samples=10)" + ], + "metadata": { + "id": "ex5hbQEF_FBD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from gluonts.dataset.pandas import PandasDataset\n", + "from gluonts.dataset.split import split" + ], + "metadata": { + "id": "cfJ6LK1ONpMf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "dataset = test_dataset.dataset\n", + "dataset = PandasDataset(dict(dataset))\n", + "train, test_template = split(dataset, offset=-128 + 20 * 64)\n", + "print(test_template)\n", + "print(len(dataset))\n", + "test_data = test_template.generate_instances(\n", + " prediction_length=64, windows=20, distance=64\n", + ")" + ], + "metadata": { + "id": "nYGxZbHnNrA-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(test_data)" + ], + "metadata": { + "id": "5TAndQWYNswE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input_it = iter(test_data.input)\n", + "label_it = iter(test_data.label)\n", + "# inp = next(input_it)\n", + "# label = next(label_it)\n", + "print(inp)\n", + 
"print(label[\"target\"].shape)" + ], + "metadata": { + "id": "b_qowpKnNup2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for inp, label in zip(input_it, label_it):\n", + " print(inp[\"item_id\"], label[\"item_id\"], label[\"target\"].shape)" + ], + "metadata": { + "id": "uruJM9vqNxs-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(test_dataset.dataset)" + ], + "metadata": { + "id": "PSDdqD-kNz7w" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "\n", + "class ChronosWindowDataset(Dataset):\n", + " \"\"\"\n", + " A PyTorch Dataset for sliding window extraction from time series data.\n", + " \"\"\"\n", + "\n", + " def __init__(self, data, context_len, horizon_len, stride=-1):\n", + " \"\"\"\n", + " Initialize the dataset with sliding window logic.\n", + "\n", + " Args:\n", + " data (pd.DataFrame): The input time series data.\n", + " context_len (int): Length of the context window.\n", + " horizon_len (int): Length of the forecast horizon.\n", + " stride (int): Step size for sliding the window.\n", + " \"\"\"\n", + " self.data = data\n", + " self.context_len = context_len\n", + " self.horizon_len = horizon_len\n", + " self.total_len = context_len + horizon_len\n", + " self.stride = stride\n", + "\n", + " if self.stride == -1:\n", + " self.stride = self.horizon_len\n", + "\n", + " # Generate start indices for sliding windows\n", + " self.indices = [\n", + " start for start in range(0, len(data) - self.total_len + 1, self.stride)\n", + " ]\n", + "\n", + " def __len__(self):\n", + " return len(self.indices)\n", + "\n", + " def __getitem__(self, idx):\n", + " start = self.indices[idx]\n", + " window = self.data.iloc[start : start + self.total_len]\n", + "\n", + " # Extract context and actuals, and convert to Torch tensors\n", + " 
context = torch.tensor(\n", + " window.iloc[: self.context_len].to_numpy().transpose(), dtype=torch.float32\n", + " )\n", + " actual = torch.tensor(\n", + " window.iloc[self.context_len :].to_numpy().transpose(), dtype=torch.float32\n", + " )\n", + "\n", + " # # Return the input as a list of tensors (one for each column)\n", + " # input_list = [context[i] for i in range(context.shape[0])]\n", + "\n", + " return context, actual" + ], + "metadata": { + "id": "0_HeGsp7N0p2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "test_data = test_dataset.dataset\n", + "print(test_data)" + ], + "metadata": { + "id": "d5gwi5DWN5OM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "# window_dataset = ChronosWindowDataset(data=test_data, context_len=128, horizon_len=64)\n", + "data_loader = DataLoader(test_data, batch_size=8, shuffle=False)" + ], + "metadata": { + "id": "gysL-YTAN61s" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "with torch.no_grad():\n", + " for i, (context, actual) in enumerate(data_loader):\n", + " print(context.shape)\n", + " print(actual.shape)" + ], + "metadata": { + "id": "vOnq2B7cN8ou" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input, actual = next(iter(data_loader))\n", + "input = input.squeeze()\n", + "actual = actual.squeeze()\n", + "print(input.shape)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "TRQ4SDaVN-Xy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "input_stack = input.reshape(-1, 128)\n", + "print(input_stack.shape)" + ], + "metadata": { + "id": "-1KkT_1GOAL8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(-1, 64)\n", + "print(actual.shape)" + 
], + "metadata": { + "id": "Df5hiY9MOAlQ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "predictions = ch.model.predict(\n", + " context=input_stack, prediction_length=64, num_samples=20\n", + ")" + ], + "metadata": { + "id": "9QeD4CI0ODBM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(predictions.shape)" + ], + "metadata": { + "id": "Hb2d_xRYOGvA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(predictions.shape)\n", + "pred_median = np.median(predictions, axis=1)\n", + "print(pred_median.shape)\n", + "pred_quantiles = np.quantile(\n", + " predictions, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], axis=1\n", + ")\n", + "print(pred_quantiles.shape)" + ], + "metadata": { + "id": "xaGEglU5OJB1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(-1, 64)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "jwKg5EBtOKdF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "mse1 = np.mean((actual.numpy() - pred_quantiles[4]) ** 2)\n", + "print(mse1)\n", + "mse1 = np.mean((actual.numpy() - pred_median) ** 2)\n", + "print(mse1)" + ], + "metadata": { + "id": "HsFP4XrkOMAB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(ch.model.model.device)" + ], + "metadata": { + "id": "o2O76zOzON49" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "actual = actual.reshape(8, 7, 64)\n", + "pred_median = pred_median.reshape(8, 7, 64)\n", + "print(actual.shape)" + ], + "metadata": { + "id": "6vgEzR4aOPZD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "mse1 = np.mean((actual - pred_median) ** 2)\n", + "print(mse1)" + ], + "metadata": { + "id": "-2_f2c-aOP8w" + }, + 
"execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "data = test_dataset.dataset\n", + "print(len(data.iloc[:10]))" + ], + "metadata": { + "id": "p2BmmGwPORRa" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/moirai.ipynb b/example/colab/moirai.ipynb new file mode 100644 index 0000000..57dde82 --- /dev/null +++ b/example/colab/moirai.ipynb @@ -0,0 +1,187 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNMrNyISIi+DxlbMTxeYGbS", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Importing Requirements**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from tsfmproject.dataset import MoiraiDataset\n", + "from tsfmproject.model import MoiraiTSModel" + ], + "metadata": { + "id": "MAojvVF6PEGP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget 
https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MoiraiDataset(name=\"ett\", mode=\"train\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)\n", + "test_dataset = MoiraiDataset(name=\"ett\", mode=\"test\", path='/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)" + ], + "metadata": { + "id": "gX48JpxwPJwN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "config = {\n", + " \"context_len\": 128,\n", + " \"horizon_len\": 64,\n", + "}\n", + "model_type = \"moirai-moe\"\n", + "model_size = \"small\"\n", + "moirai_model = MoiraiTSModel(model_type=model_type, model_size=model_size, config=config)" + ], + "metadata": { + "id": "BofOhe74PO-4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(moirai_model.model.device)" + ], + "metadata": { + "id": "2ff_VDTnPS0o" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "eval_results, trues, preds, histories = moirai_model.evaluate(test_dataset, metrics=[\"MSE\", \"MASE\"])" + ], + "metadata": { + "id": "rPNns3q5PTNe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(eval_results)" + ], + "metadata": { + "id": "9ZBsQ8gzPV6T" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for i in range(len(histories)):\n", + " print(i, histories[i].shape)" + ], + "metadata": { + "id": "9DpENZS7PXen" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + 
"print(len(test_dataset.data))" + ], + "metadata": { + "id": "fLoUWiqOPZAR" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/moment_anomaly_detection.ipynb b/example/colab/moment_anomaly_detection.ipynb new file mode 100644 index 0000000..2ed717b --- /dev/null +++ b/example/colab/moment_anomaly_detection.ipynb @@ -0,0 +1,233 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyPLZ1a9i7oPxMNhCvZBpzN+", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"reconstruction\",\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "aQSjnUjKQqlE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": 
"Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Zero-Shot Anomaly Detection Using Moment Model**" + ], + "metadata": { + "id": "CHgBRG2AQ4vD" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n", + " mode=\"train\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n", + "test_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n", + " mode=\"test\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n", + "# print(len(train_dataset))\n", + "# print(len(test_dataset))\n", + "# trues, preds, labels = mmt.evaluate(test_dataset, task_name='detection')\n", + "mmt.plot(test_dataset, task_name='detection',)" + ], + "metadata": { + "id": "-GnUpadKQ2pz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of Zero-Shot Anomaly Detection**" + ], + "metadata": { + "id": "rmBuetJARAR_" + } + }, + { + "cell_type": "code", + "source": [ + "# import numpy as np\n", + "# import matplotlib.pyplot as plt\n", + "# from tsfmproject.models.moment.momentfm.utils.anomaly_detection_metrics import adjbestf1\n", + "\n", + "# # We will use the Mean Squared Error (MSE) between the observed values and MOMENT's predictions as the anomaly score\n", + "# anomaly_scores = (trues - preds)**2\n", + "\n", + "# print(f\"Zero-shot Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\")\n", + "\n", + "# anomaly_start = 74158\n", + "# anomaly_end = 74984\n", + "# start = anomaly_start-512\n", + "# end = 
anomaly_end+512\n", + "\n", + "# plt.plot(trues[start:end], label=\"Observed\", c='darkblue')\n", + "# plt.plot(preds[start:end], label=\"Predicted\", c='red')\n", + "# plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c='black')\n", + "# plt.legend(fontsize=16)\n", + "# plt.show()" + ], + "metadata": { + "id": "SDspECriRDvN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune the Moment model for anomaly detection**" + ], + "metadata": { + "id": "4woLdKIERIUf" + } + }, + { + "cell_type": "code", + "source": [ + "finetuned_model = mmt.finetune(\n", + " train_dataset, task_name=\"detection\", mask_ratio=0.1, epoch=5\n", + ")" + ], + "metadata": { + "id": "Dw7OexJ_RGGH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the finetuned model and Visualization**" + ], + "metadata": { + "id": "HTva4_cARNWY" + } + }, + { + "cell_type": "code", + "source": [ + "trues, preds, labels = mmt.evaluate(test_dataset, task_name=\"detection\")\n", + "\n", + "anomaly_scores = (trues - preds) ** 2\n", + "\n", + "print(\n", + " f\"Zero-shot Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\"\n", + ")\n", + "\n", + "anomaly_start = 74158\n", + "anomaly_end = 74984\n", + "start = anomaly_start - 512\n", + "end = anomaly_end + 512\n", + "\n", + "plt.plot(trues[start:end], label=\"Observed\", c=\"darkblue\")\n", + "plt.plot(preds[start:end], label=\"Predicted\", c=\"red\")\n", + "plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c=\"black\")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ], + "metadata": { + "id": "Mkk6GyTyRQX8" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/moment_classification.ipynb b/example/colab/moment_classification.ipynb new file mode 100644 index 0000000..88fa7e5 --- /dev/null +++ b/example/colab/moment_classification.ipynb @@ -0,0 +1,236 @@ +{ + 
"nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyNoSDnZw0e+6sXWq4gOcrTv", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.models.moment.momentfm.models.statistical_classifiers import fit_svm\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\"task_name\": \"classification\", \"n_channels\": 1, \"num_class\": 5}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "9KSGbm9r9wVy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "n9Nl5kajSVlc" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": 
"markdown", + "source": [ + "# **Use a SVM classifier to evaluate zero-shot embeddings of Moment model**" + ], + "metadata": { + "id": "4k7hDhXGScrx" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(\n", + " name=\"ecg5000\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TRAIN.csv\",\n", + " batchsize=64,\n", + " mode=\"train\",\n", + " task_name=\"classification\",\n", + ")\n", + "test_dataset = MomentDataset(\n", + " name=\"ecg5000\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TEST.csv\",\n", + " batchsize=64,\n", + " mode=\"test\",\n", + " task_name=\"classification\",\n", + ")\n", + "\n", + "train_accuracy, train_embeddings, train_labels = mmt.evaluate(\n", + " train_dataset, task_name=\"classification\"\n", + ")\n", + "test_accuracy, test_embeddings, test_labels = mmt.evaluate(\n", + " test_dataset, task_name=\"classification\"\n", + ")\n", + "print(train_embeddings.shape, train_labels.shape)\n", + "\n", + "clf = fit_svm(features=train_embeddings, y=train_labels)\n", + "\n", + "y_pred_train = clf.predict(train_embeddings)\n", + "y_pred_test = clf.predict(test_embeddings)\n", + "train_accuracy = clf.score(train_embeddings, train_labels)\n", + "test_accuracy = clf.score(test_embeddings, test_labels)\n", + "\n", + "print(f\"Train accuracy: {train_accuracy:.2f}\")\n", + "print(f\"Test accuracy: {test_accuracy:.2f}\")" + ], + "metadata": { + "id": "1BcqPINMSf-M" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualize the embeddings**" + ], + "metadata": { + "id": "vFe6nzXMSjaT" + } + }, + { + "cell_type": "code", + "source": [ + "from sklearn.decomposition import PCA\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "test_embeddings_manifold = PCA(n_components=2).fit_transform(test_embeddings)\n", + "\n", + "plt.title(\"ECG5000 Test Embeddings\", fontsize=20)\n", + 
"plt.scatter(\n", + " test_embeddings_manifold[:, 0],\n", + " test_embeddings_manifold[:, 1],\n", + " c=test_labels.squeeze(),\n", + ")\n", + "plt.xticks(fontsize=16)\n", + "plt.yticks(fontsize=16)\n", + "plt.show()" + ], + "metadata": { + "id": "eTir0t5MSlt9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Moment model for classification**" + ], + "metadata": { + "id": "pHu3gQT_SoBO" + } + }, + { + "cell_type": "code", + "source": [ + "finetuned_model = mmt.finetune(\n", + " train_dataset, task_name=\"classification\", epoch=10, lr=0.1\n", + ")" + ], + "metadata": { + "id": "6arKH_p4Sq-Z" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate the finetuned model**" + ], + "metadata": { + "id": "3cJ8l9v6SuuB" + } + }, + { + "cell_type": "code", + "source": [ + "accuracy, embeddings, lebels = mmt.evaluate(test_dataset, task_name=\"classification\")\n", + "print(accuracy)" + ], + "metadata": { + "id": "Dk5iOBAHSsqF" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/moment_forecasting.ipynb b/example/colab/moment_forecasting.ipynb new file mode 100644 index 0000000..68f4fbb --- /dev/null +++ b/example/colab/moment_forecasting.ipynb @@ -0,0 +1,218 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyMOjbk5TzySlGkxItm+Z25O", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + 
}, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"forecasting\",\n", + " \"forecast_horizon\": 192,\n", + " \"head_dropout\": 0.1,\n", + " \"weight_decay\": 0,\n", + " \"freeze_encoder\": True, # Freeze the transformer encoder\n", + " \"freeze_embedder\": True, # Freeze the patch embedding layer\n", + " \"freeze_head\": False, # The linear forecasting head must be trained\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "MzAU2V_tTXSB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune Moment Model on the ETT dataset**" + ], + "metadata": { + "id": "49ihlnzET33r" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='train', horizon_len=192)\n", + "# dtl = train_dataset.get_data_loader()\n", + "\n", + "val_dataset = MomentDataset(name=\"ett\", datetime_col='date', 
path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n", + " mode='test', horizon_len=192)\n", + "# path = '../src/tsfmproject/models/moment/data/ETTh1.csv'\n", + "\n", + "# dataset = MomentDataset(name=\"ett\", datetime_col='date', path=path,\n", + "# mode='train', horizon=192)\n", + "\n", + "finetuned_model = mmt.finetune(train_dataset, task_name=\"forecasting\")\n", + "mmt.evaluate(val_dataset, task_name=\"forecasting\")" + ], + "metadata": { + "id": "NlPWSR-VT7Zm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Test the Finetuned Model**" + ], + "metadata": { + "id": "jVTskCB-UKPp" + } + }, + { + "cell_type": "code", + "source": [ + "# avg_loss, trues, preds, histories = mmt.evaluate(val_dataset, task_name='forecasting')\n", + "# print(\"Validation loss:\", avg_loss)\n", + "mmt.plot(val_dataset, task_name='forecasting')" + ], + "metadata": { + "id": "keWlMzyCUQJ3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of the evaluation**" + ], + "metadata": { + "id": "Lkrbk9zvUS04" + } + }, + { + "cell_type": "code", + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# # Pick a random channel and time index\n", + "# trues = np.array(trues)\n", + "# preds = np.array(preds)\n", + "# histories = np.array(histories)\n", + "# channel_idx = np.random.randint(0, 7)\n", + "# time_index = np.random.randint(0, trues.shape[0])\n", + "\n", + "# history = histories[time_index, channel_idx, :]\n", + "# true = trues[time_index, channel_idx, :]\n", + "# pred = preds[time_index, channel_idx, :]\n", + "\n", + "# plt.figure(figsize=(12, 4))\n", + "\n", + "# # Plotting the first time series from history\n", + "# plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')\n", + "\n", + "# # Plotting ground truth and prediction\n", + "# num_forecasts = len(true)\n", + "\n", + "# offset = len(history)\n", + "# 
plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (192 timesteps)', color='darkblue', linestyle='--', alpha=0.5)\n", + "# plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (192 timesteps)', color='red', linestyle='--')\n", + "\n", + "# plt.title(f\"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})\", fontsize=18)\n", + "# plt.xlabel('Time', fontsize=14)\n", + "# plt.ylabel('Value', fontsize=14)\n", + "# plt.legend(fontsize=14)\n", + "# plt.show()" + ], + "metadata": { + "id": "o7F-UCfnUUtF" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/example/colab/moment_imputation.ipynb b/example/colab/moment_imputation.ipynb new file mode 100644 index 0000000..c7a247b --- /dev/null +++ b/example/colab/moment_imputation.ipynb @@ -0,0 +1,240 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOEZ3bu8VQnUuKZKhRj8BZu", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **Install Dependencies**" + ], + "metadata": { + "id": "tpaQlDVlxetV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qF8U27Si8bbj" + }, + "outputs": [], + "source": [ + "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade -U numpy --force" + ], + "metadata": { + "id": "VGf7Oy_JxpYO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Moment Model**" + ], + "metadata": { + "id": "0cbiLmxU2Lw-" + } + }, + { + "cell_type": "code", + "source": [ + "from samay.model import MomentModel\n", + "from 
samay.dataset import MomentDataset\n", + "from samay.utils import load_args\n", + "\n", + "repo = \"AutonLab/MOMENT-1-large\"\n", + "config = {\n", + " \"task_name\": \"reconstruction\",\n", + "}\n", + "mmt = MomentModel(config=config, repo=repo)" + ], + "metadata": { + "id": "EKfkmPeNVEhs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Loading Dataset**" + ], + "metadata": { + "id": "Gt2KUswn2UtL" + } + }, + { + "cell_type": "code", + "source": [ + "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv" + ], + "metadata": { + "id": "mnIsa2TgJgCN" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Zero-Shot Imputation Using Moment Model**" + ], + "metadata": { + "id": "G2i1_KoTVHXc" + } + }, + { + "cell_type": "code", + "source": [ + "train_dataset = MomentDataset(\n", + " name=\"ett\",\n", + " datetime_col=\"date\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " mode=\"train\",\n", + " task_name=\"imputation\",\n", + ")\n", + "test_dataset = MomentDataset(\n", + " name=\"ett\",\n", + " datetime_col=\"date\",\n", + " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ETTh1.csv\",\n", + " mode=\"test\",\n", + " task_name=\"imputation\",\n", + ")\n", + "# print(len(train_dataset))\n", + "# print(len(test_dataset))\n", + "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "WhjXYZLsVJr1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Visualization of Zero-Shot Imputation**" + ], + "metadata": { + "id": "1UEUE0QGVSdA" + } + }, + { + "cell_type": "code", + "source": [ + "# import matplotlib.pyplot as plt\n", + "\n", + "# print(trues.shape, preds.shape, masks.shape)\n", + "\n", + "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n", + "# mae 
= np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n", + "# print(f'MSE: {mse}, MAE: {mae}')\n", + "\n", + "# idx = np.random.randint(trues.shape[0])\n", + "# channel_idx = np.random.randint(trues.shape[1])\n", + "\n", + "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n", + "# axs[0].set_title(f\"Channel={channel_idx}\")\n", + "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n", + "# axs[0].plot(preds[idx, channel_idx, :].squeeze(), label='Predictions', c='red')\n", + "# axs[0].legend(fontsize=16)\n", + "\n", + "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n", + "# plt.show()\n", + "mmt.plot(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "ldr2IIb1VWGj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Finetune the Moment Model for Imputation**" + ], + "metadata": { + "id": "kShP864XVZoB" + } + }, + { + "cell_type": "code", + "source": [ + "finetuned_model = mmt.finetune(train_dataset, task_name=\"imputation\")" + ], + "metadata": { + "id": "a4aYBmjDVcSx" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# **Evaluate and visualize the finetuned model**" + ], + "metadata": { + "id": "Hz1cVDpVVd0S" + } + }, + { + "cell_type": "code", + "source": [ + "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')\n", + "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n", + "# mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n", + "# print(f'MSE: {mse}, MAE: {mae}')\n", + "\n", + "# idx = np.random.randint(trues.shape[0])\n", + "# channel_idx = np.random.randint(trues.shape[1])\n", + "\n", + "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n", + "# axs[0].set_title(f\"Channel={channel_idx}\")\n", + "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n", + "# axs[0].plot(preds[idx, 
channel_idx, :].squeeze(), label='Predictions', c='red')\n", + "# axs[0].legend(fontsize=16)\n", + "\n", + "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n", + "# plt.show()\n", + "mmt.plot(test_dataset, task_name='imputation')" + ], + "metadata": { + "id": "Yi0FjTCVVflk" + }, + "execution_count": null, + "outputs": [] + } + ] +}