Skip to content

Adding the colab notebooks for the example notebooks #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions example/colab/chronos.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOHc9Zna3DRB0e4ryzjGrBk",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Showmick119/Samay/blob/main/chronos.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Install Dependencies**"
],
"metadata": {
"id": "tpaQlDVlxetV"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "259--YZOr1Lo"
},
"outputs": [],
"source": [
"!pip install samay-0.1.0-cp311-cp311-linux_x86_64.whl"
]
},
{
"cell_type": "code",
"source": [
"!pip install --upgrade -U numpy --force"
],
"metadata": {
"id": "VGf7Oy_JxpYO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Importing Requirements**"
],
"metadata": {
"id": "0cbiLmxU2Lw-"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"import numpy as np\n",
"\n",
"src_path = os.path.abspath(os.path.join(\"..\", \"src\"))\n",
"if src_path not in sys.path:\n",
" sys.path.insert(0, src_path)\n",
"\n",
"from samay.model import ChronosModel\n",
"from samay.dataset import ChronosDataset\n",
"# from tsfmproject.utils import load_args\n",
"\n",
"# arg_path = \"../config/timesfm.json\"\n",
"# args = load_args(arg_path)\n",
"repo = \"amazon/chronos-t5-small\"\n",
"chronos_model = ChronosModel(repo=repo)"
],
"metadata": {
"id": "APfhHKAU1qM3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Loading Dataset**"
],
"metadata": {
"id": "Gt2KUswn2UtL"
}
},
{
"cell_type": "code",
"source": [
"!wget https://github.com/kage08/Samay/refs/heads/main/data/data/ETTh1.csv"
],
"metadata": {
"id": "mnIsa2TgJgCN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n",
" mode='train', batch_size=8)\n",
"val_dataset = ChronosDataset(name=\"ett\", datetime_col='date', path='../src/tsfmproject/models/moment/data/ETTh1.csv',\n",
" mode='test', batch_size=8)"
],
"metadata": {
"id": "MyUXzOME1sgo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Visualize the zero-shot forecasting**"
],
"metadata": {
"id": "OgcxscvD3Ee8"
}
},
{
"cell_type": "code",
"source": [
"chronos_model.plot(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])"
],
"metadata": {
"id": "K9Zmx3li1u5k"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Evaluate the zero-shot Chronos Model**"
],
"metadata": {
"id": "2I0VE7Uu3Jqt"
}
},
{
"cell_type": "code",
"source": [
"metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n",
"print(metrics)"
],
"metadata": {
"id": "Z90jyPsp1ykH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Finetune Chronos Model on the ETT dataset**"
],
"metadata": {
"id": "iuzgw9uA3idf"
}
},
{
"cell_type": "code",
"source": [
"chronos_model.finetune(train_dataset)"
],
"metadata": {
"id": "LeFA_RVg12IS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# **Evaluate the Finetuned Chronos Model**"
],
"metadata": {
"id": "DE5-HDC53n4_"
}
},
{
"cell_type": "code",
"source": [
"metrics = chronos_model.evaluate(val_dataset, horizon_len=64, quantile_levels=[0.1, 0.5, 0.9])\n",
"print(metrics)"
],
"metadata": {
"id": "XF2bgAXi125y"
},
"execution_count": null,
"outputs": []
}
]
}
Loading