Skip to content

Gifteval done MOIRAI example works #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config/chronos.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
{
"repo": "amazon/chronos-t5-small"
"repo": "amazon/chronos-t5-small",
"config": {
"num_layers": 6
}
}
56 changes: 30 additions & 26 deletions example/chronos.ipynb

Large diffs are not rendered by default.

75 changes: 18 additions & 57 deletions example/chronosbolt.ipynb

Large diffs are not rendered by default.

44 changes: 12 additions & 32 deletions example/lptm.ipynb

Large diffs are not rendered by default.

493 changes: 493 additions & 0 deletions example/moirai.ipynb

Large diffs are not rendered by default.

138 changes: 57 additions & 81 deletions example/moment_forecasting.ipynb

Large diffs are not rendered by default.

169 changes: 58 additions & 111 deletions example/timesfm.ipynb

Large diffs are not rendered by default.

161 changes: 52 additions & 109 deletions example/tinytimemixer.ipynb

Large diffs are not rendered by default.

21 changes: 16 additions & 5 deletions leaderboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import time
import datetime
import torch
import gc

Expand Down Expand Up @@ -156,8 +157,8 @@ def calc_pred_and_context_len(freq):


if __name__ == "__main__":

for model_name in ["moment"]:
mod_times = {}
for model_name in ["chronos"]:
print(f"Evaluating model: {model_name}")
# create csv file for leaderboard if not already created
csv_path = f"leaderboard/{model_name}.csv"
Expand Down Expand Up @@ -188,7 +189,11 @@ def calc_pred_and_context_len(freq):
arg_path = "config/lptm.json"
args = load_args(arg_path)

mod_start = time.time()
mod_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
for fpath, attrs in NAMES.items():
print(f"Model eval started at: {mod_timestamp}")
print(f"Evaluating {fname} ({freq}) started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
if SERIES == "monash":
freq = attrs[0]
horizon = attrs[1]
Expand Down Expand Up @@ -266,7 +271,7 @@ def calc_pred_and_context_len(freq):
args["config"]["context_length"] = dataset.horizon_len
model = ChronosModel(**args)
start = time.time()
metrics = model.evaluate(dataset, horizon_len=dataset_config["prediction_length"], quantile_levels=[0.1, 0.5, 0.9])
metrics = model.evaluate(dataset, horizon_len=dataset.horizon_len, quantile_levels=[0.1, 0.5, 0.9])
end = time.time()
print(f"Size of dataset: {fs:.2f} MB")
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
Expand All @@ -281,7 +286,7 @@ def calc_pred_and_context_len(freq):
model = ChronosBoltModel(repo=repo)
dataset = ChronosBoltDataset(datetime_col='timestamp', path=dataset_path, mode='test', batch_size=8, context_len=context_len, horizon_len=pred_len, boundaries=[-1, -1, -1])
start = time.time()
metrics = model.evaluate(dataset, horizon_len=pred_len, quantile_levels=[0.1, 0.5, 0.9])
metrics = model.evaluate(dataset, horizon_len=dataset.horizon_lend, quantile_levels=[0.1, 0.5, 0.9])
end = time.time()
print(f"Size of dataset: {fs:.2f} MB")
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
Expand Down Expand Up @@ -360,4 +365,10 @@ def calc_pred_and_context_len(freq):
new_row = pd.DataFrame([{**{"dataset": row_name, "size_in_MB":round(fs,2), "eval_time":str(round(eval_time,2)) + unit}, **metrics}])
df = pd.concat([df, new_row], ignore_index=True)

df.to_csv(csv_path, index=False)
df.to_csv(csv_path, index=False)
mod_end = time.time()
print(f"Time taken for model {model_name}: {mod_end-mod_start:.2f} seconds")
mod_times[model_name] = round(mod_end - mod_start,2)

print("All models evaluated!")
print("Model evaluation times: ", mod_times)
52 changes: 52 additions & 0 deletions leaderboard/chronos.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
dataset,size_in_MB,eval_time,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
us_births (M),0.0,1.46s,1070716255.6687956,26338.37321553956,,0.082518286383001,32721.800923372102,0.423134031437707,0.088048193854164,0.3613017791188073,0.084168461693457,7173196090.338237,308769509665836.6
ett1 (W),0.01,0.53s,3099526.260364619,1193.7378923969372,,0.2930298804892131,1760.5471480095666,0.1912749813153974,0.512972149846269,0.1501171918408109,0.4785215521881815,248400.7417727985,6185510031.543578
ett2 (W),0.01,0.55s,14225512.48481565,2829.765762766776,,0.0947369165889474,3771.6723724119583,0.0931368103877156,0.6658664286407251,0.0614323845511389,0.3232260710576363,4668417.230741796,171132961105.95767
saugeenday (M),0.02,6.94s,627.8970473209555,15.259393398050973,0.6148912009158484,0.4136521061096469,25.05787395851762,0.1729052421341959,0.5380455088157352,0.1056241056607616,0.5129422540443951,212.56665126685607,10498.319296781188
us_births (W),0.02,15.27s,9883877.88392094,2404.706848711204,0.7868247507819572,0.0343006028613132,3143.86352819599,0.1441213682385063,0.0336398871239602,0.1879783149683595,0.0339719342505688,-87874193.55386622,722531947563.6707
ett1 (D),0.05,13.89s,77703.79733327327,182.22849569134,1.106771804823633,0.7536703572630021,278.7540086407248,0.1230130131160192,0.5339657300704401,0.1357502763936661,0.5316404110236946,-58567.70429485454,44544945.64045838
ett2 (D),0.05,13.89s,290622.9389991356,383.395674190513,1.222664009407882,178318.10799518676,539.0945547852766,0.0884058810431536,0.525808980136574,0.0601772912197115,0.3138714841483165,-224684.49591576247,610102362.5763342
solar (W),0.06,4.47s,2574359.100198917,1159.7615232126495,0.5230783571658404,35.64832805621711,1604.480944168212,0.0724966312474921,0.2063676123776172,0.0965741863536925,0.2161131928984082,-6323313.778421548,17291278705.404995
saugeenday (W),0.07,105.23s,1380.3853343704643,18.443613443697465,0.6305865483991746,0.5303415403574707,37.15353730629783,0.0953737777596653,0.5575299528058556,0.0923766407498736,0.606279591713562,218.53095006708224,21989.64635487761
jena_weather (D),0.08,5.24s,1271.9936627175637,14.121910989623458,0.0742661785147763,1406.5664033447458,35.66502015585528,0.0279164292329496,0.4283532478368708,0.0114705977766313,0.1266684432548202,-231.3159277629416,44416.8338367404
us_births (D),0.13,272.8s,360785.0886544891,357.3637462902868,0.3196578758011552,0.0380578709094196,600.6538842415731,0.097256133948998,0.0368721390324203,0.0770706088061421,0.0371304916646136,-734505.0470849834,3512340257.3552613
hierarchical_sales (W),0.15,17.09s,1050.8863073881837,12.385761262992412,0.5013297131127319,8710.79989978191,32.417376627176104,0.0397759217538857,0.7063815171620929,0.1075096258282307,0.5677228058839355,114.14585764827432,13173.016635400329
bizitobs_l2c (H),0.18,245.76s,214.63596418922825,10.218923888869194,1.5383591161749426,259170.56359589697,14.650459521435778,0.1232857612326743,0.9878199932046644,0.1763220166102508,0.6755374322073854,4.313762415615535,3202.268589669489
M_DENSE (D),0.21,66.0s,15176.088346706623,54.42842138489396,0.687189898774255,17916.080030597608,123.19126733136008,0.0353566095939711,0.1364537298466204,0.0229198268498131,0.1031914171261527,-272.1150118669217,11433415.049366308
covid_deaths (D),0.27,28.37s,22957920.907926608,880.7822132616342,0.2117632298955713,775.9698879260366,4791.442466306635,0.0276678916139554,0.5040606507142434,0.0462691824356696,0.3939194184600028,1476277.85878325,28651818879.280907
bizitobs_application (10S),0.33,358.9s,15557706.38233198,1401.1723191371186,2.077995305415274,0.1009440527334787,3944.3258463686766,0.0691195393888243,0.0951309832302461,0.0230686644834092,0.0561494504795159,-23549306.816979617,527651382178.3079
solar (D),0.35,23.26s,206800.62638232368,302.05789660582235,0.903028528562988,91.00958367480553,454.7533687421388,0.1271415020537425,0.4419523620124611,0.157557902013762,0.4168723534169756,-288752.95385953854,217677277.455618
hospital (M),0.35,34.74s,6688.969960054369,24.716277092444088,0.0547121093414436,0.2318290427882759,81.78612327317128,0.0067647744586868,0.2291511728547863,0.0621418113448518,0.0891729259218946,1132.9050183317645,1520771.093566524
saugeenday (D),0.38,786.38s,1781.8092213153784,19.34995787344081,1.1305251278712,0.646553373401029,42.211482102804425,0.066193322002307,0.5580092844387878,0.1068818458612478,0.6455013709597474,53.350101235814805,77913.61217554902
car_parts_with_missing (M),0.58,82.07s,1.2902717529707168,0.4588652940376371,0.6510613888995413,9969.40856703676,1.1359012954349144,0.0283975252864915,1.5694098392944826,0.113665897969186,1.0536828126470517,0.168018356291267,0.4966927660013864
electricity (W),0.66,38.87s,380823236582.31104,67351.21224253096,0.1133829471113846,10256554.59160086,617108.7720834238,0.0112945805101854,0.2815844816731884,0.2015180235290622,0.1731633897109376,-2866586460.8839283,2.623461745640408e+16
hierarchical_sales (D),0.9,40.55m,28.988754085780112,2.475848692515505,4.235618863260359,47420.91324963331,5.384120548964344,0.0283374750820504,0.9366422238385608,0.1262872707568716,0.7601434814176691,2.143709788328041,75.0708508600737
kdd_cup_2018_with_missing (D),1.08,42.88s,2816.4983112556693,24.18689539664323,0.5607849519207019,1.9249712841896447,53.07069164101472,0.0301281245187246,0.7485277601190685,0.1122912149614622,0.6324435191203909,370.8475677671445,46660.84743076861
LOOP_SEATTLE (D),1.13,52.83s,20.27796869216717,3.085326835740117,0.4135521420162194,113.71166262161324,4.503106560161237,0.0580882041626706,0.068666173762054,0.0446772150541244,0.0573819077241711,-46.873068285613165,1129.432381413238
SZ_TAXI (H),1.14,349.96s,7.449679432743471,1.847635665351579,2.1210984698630138,52.260896502502014,2.729410088781726,0.0478742902467435,0.319943864300778,0.0583274904547935,0.1763071930979095,3.183231431423246,73.87325857711251
ett1 (H),1.22,32.33m,128.99806681703785,5.960723658957624,0.4756073012353153,5479.840693430655,11.35773158764715,0.0403378658010352,0.4145933017517849,0.0630632021843624,0.3337127510690003,-12.459825340052234,2855.692640038825
ett2 (H),1.26,32.11m,356.52685242197504,9.885463489539712,0.8241745202484587,33215.44812622008,18.88191866368392,0.0344378937372534,0.3168701488597072,0.0513021560458156,0.1467239994328198,-62.23530813696315,34539.6915894645
jena_weather (H),1.65,67.81m,933.6369950899513,7.138803043405957,1.3123364989283992,39092.49588423617,30.55547406095921,0.0102654912724198,0.3476992085409913,0.0079178643247827,0.0669706796447115,130.77318668097283,179758.94099407512
bizitobs_l2c (5T),1.68,60.24m,48.80109265881645,3.725292456138251,3.653111685783605,45197.70324987285,6.985777885018707,0.0577336967577003,0.5722011635931349,0.0737591663548156,0.242361829828971,-8.041869188075735,1017.370235445925
restaurant (D),1.77,126.72s,214.9445405859645,9.261560480049438,0.5588257739199045,4.601487187462462,14.660987026321404,0.0167172028040471,0.5917429178508737,0.1342944317800294,0.4564819532008836,69.97731253715945,3577.847514000256
m4_hourly (H),2.43,43.19m,411239.5034629393,84.44613490579154,0.2408447057814572,10.649611282729468,641.2795829144565,0.0011343857050169,0.0838021294046007,0.1193250222833427,0.0117625377544244,366434.0697115722,42791135542.7828
bizitobs_service (10S),3.07,103.65m,36385.41431350356,51.13304102231828,3.967831241045328,298.426147608331,190.7496115684212,0.0183008357848424,0.1637362310515734,0.0283602689150836,0.0437434102087924,-46196.26490158782,82649059.01267329
M_DENSE (H),3.7,142.83m,30438.209732862622,72.26826032945687,0.3251431745108055,22574.941828348303,174.4654972562272,0.0311323156575489,0.1874125659019757,0.0327889861994436,0.1497631397996068,-2742.903610287603,20793800.38531069
ett1 (15T),4.4,134.55m,7.902966634820784,1.4070990881329977,0.965916470228538,1058.7241150459984,2.8112215556268034,0.0388429707795215,0.40611685038138,0.0584102775492825,0.3078892638652236,0.1913912653977594,37.16699395992611
ett2 (15T),4.57,131.98m,11.669641830082558,1.923974617605437,1.1296687530011345,9056.622500184922,3.416085746886714,0.0245135481450341,0.2941813921797,0.0445991624877151,0.1131694301403363,-0.1732480604518405,281.0344935534655
SZ_TAXI (15T),4.58,99.36m,22.7552028695056,3.0972015729500275,6.544316948989107,88.8631842496993,4.770241384825887,0.0709908690543078,0.4768412653501147,0.0858899442583421,0.285432235529781,12.969547112162374,210.08519533057887
electricity (D),4.63,86.77m,3847221792.3956194,6682.674150709679,7.138597957366795,759064.076810113,62025.97675486956,0.0095807810866193,0.135746750278681,0.2004144795297502,0.124884380758456,-446277123.0454973,87900049059688.58
solar (H),5.97,309.1m,1104.1975335803993,13.828688231471151,1.4813141915713617,14864.482316737676,33.22946784979259,0.0652774132148481,0.3481412353498579,0.1184332105474461,0.6435343953216613,-343.97791496435684,41461.18927455779
bitbrains_rnd (H),6.1,31.94m,1685807.7540758748,157.94062796198395,4.5976783598922655,7042.046885232568,1298.386596540443,0.0593443712053878,0.4695872744492119,0.1147877352507691,0.7166549130859665,11836.713267604291,479107665.7287495
m4_weekly (W-SUN),7.18,210.75m,233461.3473261553,68.67040650871243,1.8109400576767407,56.97187734574267,483.1783804415873,0.0116978974878563,0.0188352063096125,0.0136681064706709,0.012796277695779,-141906.65364825667,1462504939.78642
jena_weather (10T),7.18,436.6m,4235.588143111297,18.04743556410679,5.818739352321612,764089.7799143484,65.08139629042464,0.0057504997767107,0.421382865699952,0.0172230188558686,0.1661592748021694,12.438611171537069,250325.44320036672
kdd_cup_2018_with_missing (H),14.28,734.71m,68346.45840357346,89.9252662461674,56.97534126587538,42.77638372757119,261.4315558680196,0.0795906947581552,0.5645173240001358,0.4240539079215995,1.386456167563295,-5727.921547585381,7283561.359337727
bitbrains_fast_storage (H),15.64,78.75m,3566226.2155539,334.4988766711608,6.068988060989243,2654222.096550064,1888.4454494514528,0.0860224001143914,0.4375045265535703,0.2129106719138004,0.9760974821843668,-307320.8729587518,6402495684.138967
LOOP_SEATTLE (H),27.05,721.51m,46.68375526605733,3.878665272570364,8.535666702526681,92.53679131319464,6.8325511535631644,0.0920278007885866,0.0923765603848125,0.0564560149200861,0.0692173883805199,77.93875607251898,2607.36530773479
solar (10T),33.4,2221.24m,70.19500014749424,5.329740013660233,13.14189276278725,145295.1140406266,8.378245648552818,0.0942434725097648,1.243226802768305,0.3394038595613324,2.3019819289256547,-30.049849823263514,574.9339455141185
m4_yearly (A-DEC),51.4,55.75m,79908.69161742584,90.30393807643414,0.0177415755081005,0.0922809309469351,282.6812544500003,0.0018742955472169,0.0160454754854022,0.0182655408593853,0.0134168003871133,127680.82547845124,645158955.0087659
bitbrains_rnd (5T),63.69,2205.07m,4466584.184806083,275.8740227711492,120.837249288525,70848.7190274737,2113.4294842284385,0.0965924591502191,0.5309859230003978,0.0861645439921316,0.7473691867685713,31501.42173768196,655669077.0786965
electricity (H),110.58,3501.66m,1209370.8818536743,171.39892390230494,43.79730072903688,30947.12460372161,1099.7140000262225,0.0051581332081362,0.1343634508114718,0.1534817952898222,0.084284083460593,-23752.182307856423,2102830894.5911813
temperature_rain_with_missing (D),113.99,1002.36m,205.03022190583,6.241181752835473,6.657853939949664,51694.74606395123,14.31887641911299,0.015093155126153,1.2784206178794175,0.0676423399351983,0.7235590725008252,22.109837039168767,1586.5362851703171
bitbrains_fast_storage (5T),160.06,5462.14m,4541854.873548941,300.5792547371016,93.51763828613898,24035.316556981328,2131.162798462131,0.097041617769732,0.4695270773816692,0.1494352858001602,0.6638371169694463,43146.85142794323,4339175665.723207
m4_quarterly (Q-DEC),163.93,1884.55m,306778.53303239355,68.1215126598923,0.3839067096648085,0.011150869718858605,553.8759184441886,0.010908867278575226,0.009280502195086129,0.018839454020718998,0.011092976086835439,-180253.41966356008,2243600249.334237
Loading