diff --git a/example/colab/chronos.ipynb b/example/colab/chronos.ipynb
new file mode 100644
index 0000000..b47ff4c
--- /dev/null
+++ b/example/colab/chronos.ipynb
@@ -0,0 +1,213 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyOHc9Zna3DRB0e4ryzjGrBk",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "259--YZOr1Lo"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Importing Requirements**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "\n",
+ "src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n",
+ "if src_path not in sys.path:\n",
+ " sys.path.insert(0, src_path)\n",
+ "\n",
+ "from samay.model import ChronosModel\n",
+ "from samay.dataset import ChronosDataset\n",
+ "# from tsfmproject.utils import load_args\n",
+ "\n",
+ "# arg_path = \"../config/timesfm.json\"\n",
+ "# args = load_args(arg_path)\n",
+ "repo = \"amazon/chronos-t5-small\"\n",
+ "chronos_model = ChronosModel(repo=repo)"
+ ],
+ "metadata": {
+ "id": "APfhHKAU1qM3"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='ETTh1.csv',\n",
+ " mode='train', batch_size=8)\n",
+ "val_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='ETTh1.csv',\n",
+ " mode='test', batch_size=8)"
+ ],
+ "metadata": {
+ "id": "MyUXzOME1sgo"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Visualize the zero-shot forecasting**"
+ ],
+ "metadata": {
+ "id": "OgcxscvD3Ee8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chronos_model.plot(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])"
+ ],
+ "metadata": {
+ "id": "K9Zmx3li1u5k"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Evaluate the zero-shot Chronos Model**"
+ ],
+ "metadata": {
+ "id": "2I0VE7Uu3Jqt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n",
+ "print(metrics)"
+ ],
+ "metadata": {
+ "id": "Z90jyPsp1ykH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Finetune Chronos Model on the ETT dataset**"
+ ],
+ "metadata": {
+ "id": "iuzgw9uA3idf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chronos_model.finetune(train_dataset)"
+ ],
+ "metadata": {
+ "id": "LeFA_RVg12IS"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Evaluate the Finetuned Chronos Model**"
+ ],
+ "metadata": {
+ "id": "DE5-HDC53n4_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n",
+ "print(metrics)"
+ ],
+ "metadata": {
+ "id": "XF2bgAXi125y"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/chronos_trial.ipynb b/example/colab/chronos_trial.ipynb
new file mode 100644
index 0000000..106a2c6
--- /dev/null
+++ b/example/colab/chronos_trial.ipynb
@@ -0,0 +1,645 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyPxvN1ZBJ3PXV+dtaz3/lQX",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Importing Requirements**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "src_path = os.path.abspath(os.path.join(\"src\"))\n",
+ "if src_path not in sys.path:\n",
+ " sys.path.insert(0, src_path)\n",
+ "\n",
+ "print(sys.path)"
+ ],
+ "metadata": {
+ "id": "9KSGbm9r9wVy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.dataset import ChronosDataset\n",
+ "from samay.model import ChronosModel\n",
+ "from samay.visualization import ForecastVisualization"
+ ],
+ "metadata": {
+ "id": "0-Z4X0kdNPR8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = ChronosDataset(\n",
+ " name=\"ett\",\n",
+ " mode=\"train\",\n",
+ " path=\"ETTh1.csv\",\n",
+ " datetime_col=\"date\",\n",
+ " freq=\"h\",\n",
+ " context_len=128,\n",
+ " horizon_len=64,\n",
+ ")\n",
+ "test_dataset = ChronosDataset(\n",
+ " name=\"ett\",\n",
+ " mode=\"test\",\n",
+ " path=\"ETTh1.csv\",\n",
+ " datetime_col=\"date\",\n",
+ " freq=\"h\",\n",
+ " context_len=128,\n",
+ " horizon_len=64,\n",
+ ")\n",
+ "# train_dataset = ChronosDataset(name=\"ett\", mode=\"train\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n",
+ "# test_dataset = ChronosDataset(name=\"ett\", mode=\"test\", path='/nethome/abhalerao9/TIMESERIESMODELING/TSFMProject/data/dataset/timesfm_covid_pivot.csv', datetime_col='ds', freq='D', context_len=64, horizon_len=16)\n",
+ "print(len(test_dataset.dataset))\n",
+ "# print(test_dataset.dataset.shape)"
+ ],
+ "metadata": {
+ "id": "UoMPUycM-uX4"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading the Chronos Model**"
+ ],
+ "metadata": {
+ "id": "OgcxscvD3Ee8"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "repo = \"amazon/chronos-t5-small\"\n",
+ "ch = ChronosModel(config=None, repo=repo)\n",
+ "ch.load_model()"
+ ],
+ "metadata": {
+ "id": "yfcJVg7d-xE0"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(ch.model.model.device)"
+ ],
+ "metadata": {
+ "id": "lt5Scq3e-zTu"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "eval_results, trues, preds, histories = ch.evaluate(\n",
+ " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n",
+ ")\n",
+ "print(eval_results)\n",
+ "# visualization = ForecastVisualization(trues, preds[:,:,1,:], histories)\n",
+ "# visualization.plot()"
+ ],
+ "metadata": {
+ "id": "Jsmfu_hS-1S-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "visualization = ForecastVisualization(trues, preds, histories)\n",
+ "visualization.plot(channel_idx=0, time_idx=0)"
+ ],
+ "metadata": {
+ "id": "Dk_Rkgvd-373"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(trues.shape)\n",
+ "print(preds.shape)\n",
+ "print(histories.shape)"
+ ],
+ "metadata": {
+ "id": "k-Y5T3qk-6RV"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ch.finetune(train_dataset)"
+ ],
+ "metadata": {
+ "id": "TWElWIOD-75d"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "latest_run_dir = ch.get_latest_run_dir()\n",
+ "model_dir = os.path.join(latest_run_dir, \"checkpoint-final\")\n",
+ "model_type = \"seq2seq\"\n",
+ "model = ch.load_model(model_dir, model_type)"
+ ],
+ "metadata": {
+ "id": "X4FyWD2q-9_-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "eval_results, trues, preds, histories = ch.evaluate(\n",
+ " test_dataset, batch_size=8, metrics=[\"MSE\", \"MASE\"]\n",
+ ")\n",
+ "print(eval_results)"
+ ],
+ "metadata": {
+ "id": "UjsDeYg9-_0T"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "visualization = ForecastVisualization(trues, preds, histories)\n",
+ "visualization.plot(channel_idx=0, time_idx=0)"
+ ],
+ "metadata": {
+ "id": "SumTaMAd_B2d"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data = test_dataset.dataset\n",
+ "data = np.array(data).transpose()\n",
+ "\n",
+ "print(data.shape)"
+ ],
+ "metadata": {
+ "id": "Xvrrb_eL_DWh"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input = [torch.tensor(ts[:1000]) for i, ts in enumerate(data)]\n",
+ "print(input[0].shape)\n",
+ "predictions = ch.model.predict(context=input, prediction_length=64, num_samples=10)"
+ ],
+ "metadata": {
+ "id": "ex5hbQEF_FBD"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from gluonts.dataset.pandas import PandasDataset\n",
+ "from gluonts.dataset.split import split"
+ ],
+ "metadata": {
+ "id": "cfJ6LK1ONpMf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset = test_dataset.dataset\n",
+ "dataset = PandasDataset(dict(dataset))\n",
+ "train, test_template = split(dataset, offset=-128 + 20 * 64)\n",
+ "print(test_template)\n",
+ "print(len(dataset))\n",
+ "test_data = test_template.generate_instances(\n",
+ " prediction_length=64, windows=20, distance=64\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "nYGxZbHnNrA-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(test_data)"
+ ],
+ "metadata": {
+ "id": "5TAndQWYNswE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_it = iter(test_data.input)\n",
+ "label_it = iter(test_data.label)\n",
+ "inp = next(input_it)\n",
+ "label = next(label_it)\n",
+ "print(inp)\n",
+ "print(label[\"target\"].shape)"
+ ],
+ "metadata": {
+ "id": "b_qowpKnNup2"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "for inp, label in zip(input_it, label_it):\n",
+ " print(inp[\"item_id\"], label[\"item_id\"], label[\"target\"].shape)"
+ ],
+ "metadata": {
+ "id": "uruJM9vqNxs-"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(test_dataset.dataset)"
+ ],
+ "metadata": {
+ "id": "PSDdqD-kNz7w"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "\n",
+ "\n",
+ "class ChronosWindowDataset(Dataset):\n",
+ " \"\"\"\n",
+ " A PyTorch Dataset for sliding window extraction from time series data.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, data, context_len, horizon_len, stride=-1):\n",
+ " \"\"\"\n",
+ " Initialize the dataset with sliding window logic.\n",
+ "\n",
+ " Args:\n",
+ " data (pd.DataFrame): The input time series data.\n",
+ " context_len (int): Length of the context window.\n",
+ " horizon_len (int): Length of the forecast horizon.\n",
+ " stride (int): Step size for sliding the window.\n",
+ " \"\"\"\n",
+ " self.data = data\n",
+ " self.context_len = context_len\n",
+ " self.horizon_len = horizon_len\n",
+ " self.total_len = context_len + horizon_len\n",
+ " self.stride = stride\n",
+ "\n",
+ " if self.stride == -1:\n",
+ " self.stride = self.horizon_len\n",
+ "\n",
+ " # Generate start indices for sliding windows\n",
+ " self.indices = [\n",
+ " start for start in range(0, len(data) - self.total_len + 1, self.stride)\n",
+ " ]\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.indices)\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " start = self.indices[idx]\n",
+ " window = self.data.iloc[start : start + self.total_len]\n",
+ "\n",
+ " # Extract context and actuals, and convert to Torch tensors\n",
+ " context = torch.tensor(\n",
+ " window.iloc[: self.context_len].to_numpy().transpose(), dtype=torch.float32\n",
+ " )\n",
+ " actual = torch.tensor(\n",
+ " window.iloc[self.context_len :].to_numpy().transpose(), dtype=torch.float32\n",
+ " )\n",
+ "\n",
+ " # # Return the input as a list of tensors (one for each column)\n",
+ " # input_list = [context[i] for i in range(context.shape[0])]\n",
+ "\n",
+ " return context, actual"
+ ],
+ "metadata": {
+ "id": "0_HeGsp7N0p2"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "test_data = test_dataset.dataset\n",
+ "print(test_data)"
+ ],
+ "metadata": {
+ "id": "d5gwi5DWN5OM"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from torch.utils.data import Dataset\n",
+ "\n",
+ "window_dataset = ChronosWindowDataset(data=test_data, context_len=128, horizon_len=64)\n",
+ "data_loader = DataLoader(window_dataset, batch_size=8, shuffle=False)"
+ ],
+ "metadata": {
+ "id": "gysL-YTAN61s"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with torch.no_grad():\n",
+ " for i, (context, actual) in enumerate(data_loader):\n",
+ " print(context.shape)\n",
+ " print(actual.shape)"
+ ],
+ "metadata": {
+ "id": "vOnq2B7cN8ou"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input, actual = next(iter(data_loader))\n",
+ "input = input.squeeze()\n",
+ "actual = actual.squeeze()\n",
+ "print(input.shape)\n",
+ "print(actual.shape)"
+ ],
+ "metadata": {
+ "id": "TRQ4SDaVN-Xy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_stack = input.reshape(-1, 128)\n",
+ "print(input_stack.shape)"
+ ],
+ "metadata": {
+ "id": "-1KkT_1GOAL8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "actual = actual.reshape(-1, 64)\n",
+ "print(actual.shape)"
+ ],
+ "metadata": {
+ "id": "Df5hiY9MOAlQ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "predictions = ch.model.predict(\n",
+ " context=input_stack, prediction_length=64, num_samples=20\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "9QeD4CI0ODBM"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(predictions.shape)"
+ ],
+ "metadata": {
+ "id": "Hb2d_xRYOGvA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(predictions.shape)\n",
+ "pred_median = np.median(predictions, axis=1)\n",
+ "print(pred_median.shape)\n",
+ "pred_quantiles = np.quantile(\n",
+ " predictions, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], axis=1\n",
+ ")\n",
+ "print(pred_quantiles.shape)"
+ ],
+ "metadata": {
+ "id": "xaGEglU5OJB1"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "actual = actual.reshape(-1, 64)\n",
+ "print(actual.shape)"
+ ],
+ "metadata": {
+ "id": "jwKg5EBtOKdF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "mse1 = np.mean((actual.numpy() - pred_quantiles[4]) ** 2)\n",
+ "print(mse1)\n",
+ "mse1 = np.mean((actual.numpy() - pred_median) ** 2)\n",
+ "print(mse1)"
+ ],
+ "metadata": {
+ "id": "HsFP4XrkOMAB"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(ch.model.model.device)"
+ ],
+ "metadata": {
+ "id": "o2O76zOzON49"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "actual = actual.reshape(8, 7, 64)\n",
+ "pred_median = pred_median.reshape(8, 7, 64)\n",
+ "print(actual.shape)"
+ ],
+ "metadata": {
+ "id": "6vgEzR4aOPZD"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "mse1 = np.mean((actual - pred_median) ** 2)\n",
+ "print(mse1)"
+ ],
+ "metadata": {
+ "id": "-2_f2c-aOP8w"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data = test_dataset.dataset\n",
+ "print(len(data.iloc[:10]))"
+ ],
+ "metadata": {
+ "id": "p2BmmGwPORRa"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/moirai.ipynb b/example/colab/moirai.ipynb
new file mode 100644
index 0000000..57dde82
--- /dev/null
+++ b/example/colab/moirai.ipynb
@@ -0,0 +1,187 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyNMrNyISIi+DxlbMTxeYGbS",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Importing Requirements**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.dataset import MoiraiDataset\n",
+ "from samay.model import MoiraiTSModel"
+ ],
+ "metadata": {
+ "id": "MAojvVF6PEGP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = MoiraiDataset(name=\"ett\", mode=\"train\", path='ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)\n",
+ "test_dataset = MoiraiDataset(name=\"ett\", mode=\"test\", path='ETTh1.csv', datetime_col='date', freq='h', context_len=128, horizon_len=64, normalize=False)"
+ ],
+ "metadata": {
+ "id": "gX48JpxwPJwN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "config = {\n",
+ " \"context_len\": 128,\n",
+ " \"horizon_len\": 64,\n",
+ "}\n",
+ "model_type = \"moirai-moe\"\n",
+ "model_size = \"small\"\n",
+ "moirai_model = MoiraiTSModel(model_type=model_type, model_size=model_size, config=config)"
+ ],
+ "metadata": {
+ "id": "BofOhe74PO-4"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(moirai_model.model.device)"
+ ],
+ "metadata": {
+ "id": "2ff_VDTnPS0o"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "eval_results, trues, preds, histories = moirai_model.evaluate(test_dataset, metrics=[\"MSE\", \"MASE\"])"
+ ],
+ "metadata": {
+ "id": "rPNns3q5PTNe"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(eval_results)"
+ ],
+ "metadata": {
+ "id": "9ZBsQ8gzPV6T"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "for i in range(len(histories)):\n",
+ " print(i, histories[i].shape)"
+ ],
+ "metadata": {
+ "id": "9DpENZS7PXen"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(len(test_dataset.data))"
+ ],
+ "metadata": {
+ "id": "fLoUWiqOPZAR"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/moment_anomaly_detection.ipynb b/example/colab/moment_anomaly_detection.ipynb
new file mode 100644
index 0000000..2ed717b
--- /dev/null
+++ b/example/colab/moment_anomaly_detection.ipynb
@@ -0,0 +1,233 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyPLZ1a9i7oPxMNhCvZBpzN+",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Moment Model**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.model import MomentModel\n",
+ "from samay.dataset import MomentDataset\n",
+ "from samay.utils import load_args\n",
+ "\n",
+ "repo = \"AutonLab/MOMENT-1-large\"\n",
+ "config = {\n",
+ " \"task_name\": \"reconstruction\",\n",
+ "}\n",
+ "mmt = MomentModel(config=config, repo=repo)"
+ ],
+ "metadata": {
+ "id": "aQSjnUjKQqlE"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Zero-Shot Anomaly Detection Using Moment Model**"
+ ],
+ "metadata": {
+ "id": "CHgBRG2AQ4vD"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n",
+ " mode=\"train\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n",
+ "test_dataset = MomentDataset(name=\"ett\", path='../src/tsfmproject/models/moment/data/198_UCR_Anomaly_tiltAPB2_50000_124159_124985.out',\n",
+ " mode=\"test\", boundaries=[50000, 50000, 0], task_name='detection', stride=512)\n",
+ "# print(len(train_dataset))\n",
+ "# print(len(test_dataset))\n",
+ "# trues, preds, labels = mmt.evaluate(test_dataset, task_name='detection')\n",
+ "mmt.plot(test_dataset, task_name='detection',)"
+ ],
+ "metadata": {
+ "id": "-GnUpadKQ2pz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Visualization of Zero-Shot Anomaly Detection**"
+ ],
+ "metadata": {
+ "id": "rmBuetJARAR_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from samay.models.moment.momentfm.utils.anomaly_detection_metrics import adjbestf1\n",
+ "\n",
+ "# # We will use the Mean Squared Error (MSE) between the observed values and MOMENT's predictions as the anomaly score\n",
+ "# anomaly_scores = (trues - preds)**2\n",
+ "\n",
+ "# print(f\"Zero-shot Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\")\n",
+ "\n",
+ "# anomaly_start = 74158\n",
+ "# anomaly_end = 74984\n",
+ "# start = anomaly_start-512\n",
+ "# end = anomaly_end+512\n",
+ "\n",
+ "# plt.plot(trues[start:end], label=\"Observed\", c='darkblue')\n",
+ "# plt.plot(preds[start:end], label=\"Predicted\", c='red')\n",
+ "# plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c='black')\n",
+ "# plt.legend(fontsize=16)\n",
+ "# plt.show()"
+ ],
+ "metadata": {
+ "id": "SDspECriRDvN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Finetune the Moment model for anomaly detection**"
+ ],
+ "metadata": {
+ "id": "4woLdKIERIUf"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "finetuned_model = mmt.finetune(\n",
+ " train_dataset, task_name=\"detection\", mask_ratio=0.1, epoch=5\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "Dw7OexJ_RGGH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Evaluate the finetuned model and Visualization**"
+ ],
+ "metadata": {
+ "id": "HTva4_cARNWY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trues, preds, labels = mmt.evaluate(test_dataset, task_name=\"detection\")\n",
+ "\n",
+ "anomaly_scores = (trues - preds) ** 2\n",
+ "\n",
+ "print(\n",
+ " f\"Finetuned Adjusted Best F1 Score: {adjbestf1(y_true=labels, y_scores=anomaly_scores)}\"\n",
+ ")\n",
+ "\n",
+ "anomaly_start = 74158\n",
+ "anomaly_end = 74984\n",
+ "start = anomaly_start - 512\n",
+ "end = anomaly_end + 512\n",
+ "\n",
+ "plt.plot(trues[start:end], label=\"Observed\", c=\"darkblue\")\n",
+ "plt.plot(preds[start:end], label=\"Predicted\", c=\"red\")\n",
+ "plt.plot(anomaly_scores[start:end], label=\"Anomaly Score\", c=\"black\")\n",
+ "plt.legend(fontsize=16)\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "Mkk6GyTyRQX8"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/moment_classification.ipynb b/example/colab/moment_classification.ipynb
new file mode 100644
index 0000000..88fa7e5
--- /dev/null
+++ b/example/colab/moment_classification.ipynb
@@ -0,0 +1,236 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyNoSDnZw0e+6sXWq4gOcrTv",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Moment Model**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.model import MomentModel\n",
+ "from samay.dataset import MomentDataset\n",
+ "from samay.models.moment.momentfm.models.statistical_classifiers import fit_svm\n",
+ "from samay.utils import load_args\n",
+ "\n",
+ "repo = \"AutonLab/MOMENT-1-large\"\n",
+ "config = {\"task_name\": \"classification\", \"n_channels\": 1, \"num_class\": 5}\n",
+ "mmt = MomentModel(config=config, repo=repo)"
+ ],
+ "metadata": {
+ "id": "9KSGbm9r9wVy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "n9Nl5kajSVlc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Use a SVM classifier to evaluate zero-shot embeddings of Moment model**"
+ ],
+ "metadata": {
+ "id": "4k7hDhXGScrx"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = MomentDataset(\n",
+ " name=\"ecg5000\",\n",
+ " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TRAIN.csv\",\n",
+ " batchsize=64,\n",
+ " mode=\"train\",\n",
+ " task_name=\"classification\",\n",
+ ")\n",
+ "test_dataset = MomentDataset(\n",
+ " name=\"ecg5000\",\n",
+ " path=\"/nethome/sli999/TSFMProject/src/tsfmproject/models/moment/data/ECG5000_TEST.csv\",\n",
+ " batchsize=64,\n",
+ " mode=\"test\",\n",
+ " task_name=\"classification\",\n",
+ ")\n",
+ "\n",
+ "train_accuracy, train_embeddings, train_labels = mmt.evaluate(\n",
+ " train_dataset, task_name=\"classification\"\n",
+ ")\n",
+ "test_accuracy, test_embeddings, test_labels = mmt.evaluate(\n",
+ " test_dataset, task_name=\"classification\"\n",
+ ")\n",
+ "print(train_embeddings.shape, train_labels.shape)\n",
+ "\n",
+ "clf = fit_svm(features=train_embeddings, y=train_labels)\n",
+ "\n",
+ "y_pred_train = clf.predict(train_embeddings)\n",
+ "y_pred_test = clf.predict(test_embeddings)\n",
+ "train_accuracy = clf.score(train_embeddings, train_labels)\n",
+ "test_accuracy = clf.score(test_embeddings, test_labels)\n",
+ "\n",
+ "print(f\"Train accuracy: {train_accuracy:.2f}\")\n",
+ "print(f\"Test accuracy: {test_accuracy:.2f}\")"
+ ],
+ "metadata": {
+ "id": "1BcqPINMSf-M"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Visualize the embeddings**"
+ ],
+ "metadata": {
+ "id": "vFe6nzXMSjaT"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from sklearn.decomposition import PCA\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "\n",
+ "test_embeddings_manifold = PCA(n_components=2).fit_transform(test_embeddings)\n",
+ "\n",
+ "plt.title(\"ECG5000 Test Embeddings\", fontsize=20)\n",
+ "plt.scatter(\n",
+ " test_embeddings_manifold[:, 0],\n",
+ " test_embeddings_manifold[:, 1],\n",
+ " c=test_labels.squeeze(),\n",
+ ")\n",
+ "plt.xticks(fontsize=16)\n",
+ "plt.yticks(fontsize=16)\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "id": "eTir0t5MSlt9"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Finetune Moment model for classification**"
+ ],
+ "metadata": {
+ "id": "pHu3gQT_SoBO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "finetuned_model = mmt.finetune(\n",
+ " train_dataset, task_name=\"classification\", epoch=10, lr=0.1\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "6arKH_p4Sq-Z"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Evaluate the finetuned model**"
+ ],
+ "metadata": {
+ "id": "3cJ8l9v6SuuB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "accuracy, embeddings, labels = mmt.evaluate(test_dataset, task_name=\"classification\")\n",
+ "print(accuracy)"
+ ],
+ "metadata": {
+ "id": "Dk5iOBAHSsqF"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/moment_forecasting.ipynb b/example/colab/moment_forecasting.ipynb
new file mode 100644
index 0000000..68f4fbb
--- /dev/null
+++ b/example/colab/moment_forecasting.ipynb
@@ -0,0 +1,218 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyMOjbk5TzySlGkxItm+Z25O",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
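+ "<a href=\"https://colab.research.google.com/github/kage08/Samay/blob/main/example/colab/moment_forecasting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"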
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Moment Model**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.model import MomentModel\n",
+ "from samay.dataset import MomentDataset\n",
+ "from samay.utils import load_args\n",
+ "\n",
+ "repo = \"AutonLab/MOMENT-1-large\"\n",
+ "config = {\n",
+ " \"task_name\": \"forecasting\",\n",
+ " \"forecast_horizon\": 192,\n",
+ " \"head_dropout\": 0.1,\n",
+ " \"weight_decay\": 0,\n",
+ " \"freeze_encoder\": True, # Freeze the transformer encoder\n",
+ " \"freeze_embedder\": True, # Freeze the patch embedding layer\n",
+ " \"freeze_head\": False, # The linear forecasting head must be trained\n",
+ "}\n",
+ "mmt = MomentModel(config=config, repo=repo)"
+ ],
+ "metadata": {
+ "id": "MzAU2V_tTXSB"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Finetune Moment Model on the ETT dataset**"
+ ],
+ "metadata": {
+ "id": "49ihlnzET33r"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='ETTh1.csv',\n",
+ " mode='train', horizon_len=192)\n",
+ "# dtl = train_dataset.get_data_loader()\n",
+ "\n",
+ "val_dataset = MomentDataset(name=\"ett\", datetime_col='date', path='ETTh1.csv',\n",
+ " mode='test', horizon_len=192)\n",
+ "# path = '../src/tsfmproject/models/moment/data/ETTh1.csv'\n",
+ "\n",
+ "# dataset = MomentDataset(name=\"ett\", datetime_col='date', path=path,\n",
+ "# mode='train', horizon=192)\n",
+ "\n",
+ "finetuned_model = mmt.finetune(train_dataset, task_name=\"forecasting\")\n",
+ "mmt.evaluate(val_dataset, task_name=\"forecasting\")"
+ ],
+ "metadata": {
+ "id": "NlPWSR-VT7Zm"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Test the Finetuned Model**"
+ ],
+ "metadata": {
+ "id": "jVTskCB-UKPp"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# avg_loss, trues, preds, histories = mmt.evaluate(val_dataset, task_name='forecasting')\n",
+ "# print(\"Validation loss:\", avg_loss)\n",
+ "mmt.plot(val_dataset, task_name='forecasting')"
+ ],
+ "metadata": {
+ "id": "keWlMzyCUQJ3"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Visualization of the evaluation**"
+ ],
+ "metadata": {
+ "id": "Lkrbk9zvUS04"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# import matplotlib.pyplot as plt\n",
+ "\n",
+ "# # Pick a random channel and time index\n",
+ "# trues = np.array(trues)\n",
+ "# preds = np.array(preds)\n",
+ "# histories = np.array(histories)\n",
+ "# channel_idx = np.random.randint(0, 7)\n",
+ "# time_index = np.random.randint(0, trues.shape[0])\n",
+ "\n",
+ "# history = histories[time_index, channel_idx, :]\n",
+ "# true = trues[time_index, channel_idx, :]\n",
+ "# pred = preds[time_index, channel_idx, :]\n",
+ "\n",
+ "# plt.figure(figsize=(12, 4))\n",
+ "\n",
+ "# # Plotting the first time series from history\n",
+ "# plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')\n",
+ "\n",
+ "# # Plotting ground truth and prediction\n",
+ "# num_forecasts = len(true)\n",
+ "\n",
+ "# offset = len(history)\n",
+ "# plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (192 timesteps)', color='darkblue', linestyle='--', alpha=0.5)\n",
+ "# plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (192 timesteps)', color='red', linestyle='--')\n",
+ "\n",
+ "# plt.title(f\"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})\", fontsize=18)\n",
+ "# plt.xlabel('Time', fontsize=14)\n",
+ "# plt.ylabel('Value', fontsize=14)\n",
+ "# plt.legend(fontsize=14)\n",
+ "# plt.show()"
+ ],
+ "metadata": {
+ "id": "o7F-UCfnUUtF"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
diff --git a/example/colab/moment_imputation.ipynb b/example/colab/moment_imputation.ipynb
new file mode 100644
index 0000000..c7a247b
--- /dev/null
+++ b/example/colab/moment_imputation.ipynb
@@ -0,0 +1,240 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyOEZ3bu8VQnUuKZKhRj8BZu",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
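+ "<a href=\"https://colab.research.google.com/github/kage08/Samay/blob/main/example/colab/moment_imputation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"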
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Install Dependencies**"
+ ],
+ "metadata": {
+ "id": "tpaQlDVlxetV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qF8U27Si8bbj"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install --upgrade -U numpy --force"
+ ],
+ "metadata": {
+ "id": "VGf7Oy_JxpYO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Moment Model**"
+ ],
+ "metadata": {
+ "id": "0cbiLmxU2Lw-"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from samay.model import MomentModel\n",
+ "from samay.dataset import MomentDataset\n",
+ "from samay.utils import load_args\n",
+ "\n",
+ "repo = \"AutonLab/MOMENT-1-large\"\n",
+ "config = {\n",
+ " \"task_name\": \"reconstruction\",\n",
+ "}\n",
+ "mmt = MomentModel(config=config, repo=repo)"
+ ],
+ "metadata": {
+ "id": "EKfkmPeNVEhs"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Loading Dataset**"
+ ],
+ "metadata": {
+ "id": "Gt2KUswn2UtL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!wget https://raw.githubusercontent.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
+ ],
+ "metadata": {
+ "id": "mnIsa2TgJgCN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Zero-Shot Imputation Using Moment Model**"
+ ],
+ "metadata": {
+ "id": "G2i1_KoTVHXc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_dataset = MomentDataset(\n",
+ " name=\"ett\",\n",
+ " datetime_col=\"date\",\n",
+ " path=\"ETTh1.csv\",\n",
+ " mode=\"train\",\n",
+ " task_name=\"imputation\",\n",
+ ")\n",
+ "test_dataset = MomentDataset(\n",
+ " name=\"ett\",\n",
+ " datetime_col=\"date\",\n",
+ " path=\"ETTh1.csv\",\n",
+ " mode=\"test\",\n",
+ " task_name=\"imputation\",\n",
+ ")\n",
+ "# print(len(train_dataset))\n",
+ "# print(len(test_dataset))\n",
+ "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')"
+ ],
+ "metadata": {
+ "id": "WhjXYZLsVJr1"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Visualization of Zero-Shot Imputation**"
+ ],
+ "metadata": {
+ "id": "1UEUE0QGVSdA"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# import matplotlib.pyplot as plt\n",
+ "\n",
+ "# print(trues.shape, preds.shape, masks.shape)\n",
+ "\n",
+ "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n",
+ "# mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n",
+ "# print(f'MSE: {mse}, MAE: {mae}')\n",
+ "\n",
+ "# idx = np.random.randint(trues.shape[0])\n",
+ "# channel_idx = np.random.randint(trues.shape[1])\n",
+ "\n",
+ "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n",
+ "# axs[0].set_title(f\"Channel={channel_idx}\")\n",
+ "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n",
+ "# axs[0].plot(preds[idx, channel_idx, :].squeeze(), label='Predictions', c='red')\n",
+ "# axs[0].legend(fontsize=16)\n",
+ "\n",
+ "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n",
+ "# plt.show()\n",
+ "mmt.plot(test_dataset, task_name='imputation')"
+ ],
+ "metadata": {
+ "id": "ldr2IIb1VWGj"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Finetune the Moment Model for Imputation**"
+ ],
+ "metadata": {
+ "id": "kShP864XVZoB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "finetuned_model = mmt.finetune(train_dataset, task_name=\"imputation\")"
+ ],
+ "metadata": {
+ "id": "a4aYBmjDVcSx"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **Evaluate the finetuned model and visualization**"
+ ],
+ "metadata": {
+ "id": "Hz1cVDpVVd0S"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# trues, preds, masks = mmt.evaluate(test_dataset, task_name='imputation')\n",
+ "# mse = np.mean((trues[masks==0] - preds[masks==0])**2)\n",
+ "# mae = np.mean(np.abs(trues[masks==0] - preds[masks==0]))\n",
+ "# print(f'MSE: {mse}, MAE: {mae}')\n",
+ "\n",
+ "# idx = np.random.randint(trues.shape[0])\n",
+ "# channel_idx = np.random.randint(trues.shape[1])\n",
+ "\n",
+ "# fig, axs = plt.subplots(2, 1, figsize=(10, 5))\n",
+ "# axs[0].set_title(f\"Channel={channel_idx}\")\n",
+ "# axs[0].plot(trues[idx, channel_idx, :].squeeze(), label='Ground Truth', c='darkblue')\n",
+ "# axs[0].plot(preds[idx, channel_idx, :].squeeze(), label='Predictions', c='red')\n",
+ "# axs[0].legend(fontsize=16)\n",
+ "\n",
+ "# axs[1].imshow(np.tile(masks[np.newaxis, idx, channel_idx], reps=(8, 1)), cmap='binary')\n",
+ "# plt.show()\n",
+ "mmt.plot(test_dataset, task_name='imputation')"
+ ],
+ "metadata": {
+ "id": "Yi0FjTCVVflk"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}